// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensor_ops.h"
#include "quant_util.h"
#include "template_types.h"
#include "half.hpp"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
                           std::vector<int32_t> input_shape,
                           std::vector<int32_t> output_shape,
                           std::string& msg)
{
    if (attribute->pad().size() != 4)
    {
        msg = "illegal size for attribute padding";
        return 1;
    }

    if (attribute->kernel().size() != 2)
    {
        msg = "illegal size for attribute kernel";
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        msg = "illegal size for attribute stride";
        return 1;
    }

    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->kernel())
    {
        if (i < 1)
        {
            msg = "At least one kernel dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    int32_t IH = input_shape[1];
    int32_t IW = input_shape[2];
    int32_t OH = output_shape[1];
    int32_t OW = output_shape[2];

    int32_t pad_top    = attribute->pad()[0];
    int32_t pad_bottom = attribute->pad()[1];
    int32_t pad_left   = attribute->pad()[2];
    int32_t pad_right  = attribute->pad()[3];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_y = attribute->kernel()[0];
    int32_t kernel_x = attribute->kernel()[1];

    if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
    {
        msg = "At least one pad is >= kernel dimension";
        return 1;
    }

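    // Worked example (assumed values, not from the TOSA spec): IH = 7, pad_top = pad_bottom = 1,
    // kernel_y = 3, stride_y = 2. Then full_H = 7 + 1 + 1 - 3 = 6, which divides the stride
    // exactly (6 % 2 == 0), so the expected output height is OH = 6 / 2 + 1 = 4. Any parameter
    // set where full_H or full_W is not a multiple of the stride is rejected below.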
    int32_t full_H = IH + pad_top + pad_bottom - kernel_y;
    int32_t full_W = IW + pad_left + pad_right - kernel_x;

    if ((full_H % stride_y != 0) ||
        (full_W % stride_x != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) ||
        (OW != (full_W / stride_x) + 1))
    {
        msg = "Mismatch between output shape provided and expected output shape (" +
              std::to_string((full_H / stride_y) + 1) + "," +
              std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    return 0;
}

int check_conv_attribute(tosa::TosaConvAttribute* attribute,
                         uint32_t conv_dimension,
                         std::vector<int32_t> input_shape,
                         std::vector<int32_t> output_shape,
                         std::vector<int32_t> weights,
                         uint32_t offset_kernel,
                         DType InDtype,
                         DType WeightDtype,
                         std::string& msg)
{
    if (attribute->pad().size() != (2 * conv_dimension))
    {
        msg = "Illegal size for attribute pad";
        return 1;
    }

    if (attribute->stride().size() != conv_dimension)
    {
        msg = "Illegal size for attribute stride";
        return 1;
    }

    if (attribute->dilation().size() != conv_dimension)
    {
        msg = "Illegal size for attribute dilation";
        return 1;
    }

    for (int32_t i : attribute->pad())
    {
        if (i < 0)
        {
            msg = "At least one pad is smaller than zero";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->dilation())
    {
        if (i < 1)
        {
            msg = "At least one dilation dimension is smaller than one";
            return 1;
        }
    }

    ASSERT_MSG(conv_dimension == 2 || conv_dimension == 3, "Unsupported convolution dimension")

    int32_t offset_d = conv_dimension == 3 ? 1 : 0;
    int32_t ID = conv_dimension == 3 ? input_shape[1] : 1;
    int32_t IH = input_shape[1 + offset_d];
    int32_t IW = input_shape[2 + offset_d];
    int32_t OD = conv_dimension == 3 ? output_shape[1] : 1;
    int32_t OH = output_shape[1 + offset_d];
    int32_t OW = output_shape[2 + offset_d];

    int32_t stride_d   = conv_dimension == 3 ? attribute->stride()[0] : 1;
    int32_t stride_y   = attribute->stride()[0 + offset_d];
    int32_t stride_x   = attribute->stride()[1 + offset_d];
    int32_t kernel_d   = conv_dimension == 3 ? weights[offset_kernel] : 1;
    int32_t kernel_h   = weights[offset_kernel + offset_d];
    int32_t kernel_w   = weights[offset_kernel + 1 + offset_d];
    int32_t dilation_d = conv_dimension == 3 ? attribute->dilation()[0] : 1;
    int32_t dilation_y = attribute->dilation()[0 + offset_d];
    int32_t dilation_x = attribute->dilation()[1 + offset_d];

    offset_d *= 2;
    int32_t pad_d0     = conv_dimension == 3 ? attribute->pad()[0] : 0;
    int32_t pad_d1     = conv_dimension == 3 ? attribute->pad()[1] : 0;
    int32_t pad_top    = attribute->pad()[0 + offset_d];
    int32_t pad_bottom = attribute->pad()[1 + offset_d];
    int32_t pad_left   = attribute->pad()[2 + offset_d];
    int32_t pad_right  = attribute->pad()[3 + offset_d];

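    // Worked example (assumed values, not from the TOSA spec): IH = 8, pad_top = pad_bottom = 1,
    // kernel_h = 3, dilation_y = 2, stride_y = 1. The dilated kernel spans
    // (kernel_h - 1) * dilation_y + 1 = 5 rows, so full_H = 8 - 1 + 1 + 1 - 4 = 5 and the
    // expected output height is OH = 5 / 1 + 1 = 6.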
    int32_t full_D = ID - 1 + pad_d0 + pad_d1 - (kernel_d - 1) * dilation_d;
    int32_t full_H = IH - 1 + pad_top + pad_bottom - (kernel_h - 1) * dilation_y;
    int32_t full_W = IW - 1 + pad_left + pad_right - (kernel_w - 1) * dilation_x;

    if ((full_H % stride_y != 0) ||
        (full_W % stride_x != 0) ||
        (full_D % stride_d != 0))
    {
        msg = "Parameters must yield exact integer output dimensions";
        return 1;
    }

    if ((OH != (full_H / stride_y) + 1) ||
        (OW != (full_W / stride_x) + 1) ||
        (OD != (full_D / stride_d) + 1))
    {
        std::string msg_d = "";
        if (conv_dimension == 3)
        {
            msg_d += std::to_string((full_D / stride_d) + 1) + ",";
        }
        msg = "Mismatch between output shape provided and expected output shape (" +
              msg_d +
              std::to_string((full_H / stride_y) + 1) + "," +
              std::to_string((full_W / stride_x) + 1) + ")";
        return 1;
    }

    if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
        msg = "Input zero point must be zero for non-int8 data";
        return 1;
    }
    if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
        msg = "Weight zero point must be zero for non-int8 data";
        return 1;
    }

    return 0;
}

int check_fft_shape(const std::vector<int32_t>& in_real,
                    const std::vector<int32_t>& in_imag,
                    const std::vector<int32_t>& out_real,
                    const std::vector<int32_t>& out_imag,
                    std::string& msg) {
    const bool is_rfft = in_imag.empty();
    auto is_power_of_two = [](int32_t n) -> bool
    {
        return (n & (n - 1)) == 0 && n > 0;
    };

    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
    {
        msg = "Input height and width must be a power of two";
        return 1;
    }

    // RFFT does not have a second input
    if (!is_rfft)
    {
        bool input_check = true;
        for (size_t i = 0; i < in_real.size(); i++)
        {
            if (in_real[i] != in_imag[i])
            {
                input_check = false;
                break;
            }
        }
        if (!input_check)
        {
            msg = "Mismatch between real input shape and imaginary input shape";
            return 1;
        }
    }

    bool output_check = true;
    for (size_t i = 0; i < out_real.size(); i++)
    {
        if (out_real[i] != out_imag[i])
        {
            output_check = false;
            break;
        }
    }
    if (!output_check)
    {
        msg = "Mismatch between real output shape and imaginary output shape";
        return 1;
    }

    if (in_real[0] != out_real[0])
    {
        msg = "Input and output batch size don't match";
        return 1;
    }
    if (in_real[1] != out_real[1])
    {
        msg = "Input and output height don't match";
        return 1;
    }

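    // The RFFT output keeps only the non-redundant half of the spectrum: a real input
    // row of width W has a conjugate-symmetric transform, so only W / 2 + 1 complex bins
    // are unique (e.g. W = 8 yields 5 bins, DC through Nyquist). The full-complex FFT
    // keeps the full width.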
    if (is_rfft)
    {
        if (in_real[2] / 2 + 1 != out_real[2])
        {
            msg = "Output width is expected to match input width / 2 + 1";
            return 1;
        }
    } else {
        if (in_real[2] != out_real[2])
        {
            msg = "Input and output width don't match";
            return 1;
        }
    }

    return 0;
}

template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1, 4);

    INIT_ATTRIBUTE(Axis);
}

template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]))
    {
        return 1;
    }

    int32_t output_rank = inputs[0]->getRank() - 1;
    if (output_rank != outputs[0]->getRank())
    {
        printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
        return 1;
    }

    if (outputs[0]->getDtype() != DType_INT32)
    {
        printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
    {
        printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input) - 1]");
        return 1;
    }

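    // The output shape must equal the input shape with the reduced axis removed,
    // e.g. (assumed shapes) input [2, 3, 4] with axis = 1 gives output [2, 4]:
    // dimensions before the axis match directly, dimensions after it shift left by one.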
    bool shape_check = true;
    for (int32_t i = 0; i < input->getRank(); i++)
    {
        if (i < attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i])
            {
                shape_check = false;
                break;
            }
        }
        else if (i > attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i - 1])
            {
                shape_check = false;
                break;
            }
        }
        // No need to check i == axis
    }
    if (!shape_check)
    {
        printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
        return 1;
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
    Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());

    this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });

    return GraphNode::eval();
}

template <DType Dtype, DType AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
}

template <DType Dtype, DType AccDtype>
OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype, DType AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");

    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpAvgPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

// This calculates the number of padding elements used for each location along an axis.
// Average pooling only divides by the number of elements actually used, not including padding.
// This function uses left/right naming, but is also used for vertical padding with top/bottom.
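// Worked example (assumed values, not from the TOSA spec): in_size = 4, kernel_size = 3,
// stride = 1, pad_left = pad_right = 1 gives out_size = 4 and divisors [2, 3, 3, 2].
// The first and last windows each overlap one padding element, so they only cover
// two real input elements and are divided by 2 instead of 3.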
template <DType Dtype, DType AccDtype>
ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
    ETensor1<int32_t> result(out_size);

    result.setConstant(kernel_size);

    // adjust divisors on the left side for padding
    // We start at the leftmost output element, and remove pad_left - (index * stride) elements
    // until we have no more padding being used
    for (int index = 0; (index <= pad_left / stride) && (index < out_size); index++) {
        int32_t adjust = pad_left - (index * stride);
        result(index) -= adjust;
    }

    // The process repeats on the right side. Padding starts taking effect as we
    // near the rightmost input element. The first output element which touches
    // padding is defined in the initialization of index below. Then we keep moving
    // to the right, increasing padding until we get to the last output element.
    int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
    for (; index < out_size; index++) {
        int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
        result(index) -= adjust;
    }
    return result;
}

// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
template <DType Dtype, DType AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];
    int kernel_y   = this->attribute->kernel()[0];
    int kernel_x   = this->attribute->kernel()[1];
    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();

    DEBUG_INFO(OP,
               "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_val = this->in->getTensor();
    if (Dtype == DType_INT8)
    {
        input_val = input_val - (InEigenType)attribute->input_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // assuming input and output have same scales
    // so input and output scaling is not required
    // TODO: check whether TOSA actually makes this assumption

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    ETensor2<InEigenType> input_extract_patches =
        input_padded.extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // 1D result with [N * H * W * C]
    ETensor1<AccEigenType> out_1d(this->out->getElementCount());
    out_1d.setZero();

    // sum pool
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        for (int32_t j = 0; j < kernel_y * kernel_x; j++)
        {
            out_1d(i) += (AccEigenType)input_extract_patches(j, i);
        }
    }

    // reshape result to [N, H, W, C] and divide by div_map
    ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);

    // calculate 1d height/width div_map (number of elements this pooling window covers)
    // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_y, stride_y, pad_top, pad_bottom);
    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_x, stride_x, pad_left, pad_right);
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
    Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };

    ETensor2<int32_t> dm2_w = div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width });
    ETensor2<int32_t> dm2_h = div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 });
    ETensor4<int32_t> div_map =
        dm2_h.contract(dm2_w, contract_dims)
            .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
            .broadcast(bcast);
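    // For integer types the divide below is done in fixed point: reciprocal_scale(div, ...)
    // produces a multiplier/shift pair approximating 1/div, and apply_scale_32() performs
    // the scaled multiply, so sum/div never uses an actual integer division.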
    if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
    {
        try
        {
            this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
                int32_t multiplier, shift;
                TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);

                return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
            });
        }
        catch (std::string desc)
        {
            REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
        }
        this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
        this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
        this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
    }
    else
    {
        // Case for float types
        this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                                   TosaAttributeBase* attribute_,
                                                   uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' checked separately since it doesn't make sense to make the required rank range from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv2d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpConv2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height       = this->weight->getShape()[1];
    int f_width        = this->weight->getShape()[2];
    int f_in_channels  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL, "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL, "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
               "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top,
               pad_bottom, pad_left, pad_right);

    // GEMM-conv2d, left matrix is input, right matrix is weight
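    // Shape sketch (assumed example, not from the spec): N = 1, H = W = 8, C = 4,
    // KH = KW = 3, OC = 16, stride 1, no padding gives OH = OW = 6. Then im2col_input
    // is [1*6*6, 3*3*4] = [36, 36], im2col_weight is [36, 16], and their contraction
    // yields [36, 16], which is reshaped back to [1, 6, 6, 16].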
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = out_batch * out_height * out_width;
    im2col_input_dims[1] = f_height * f_width * f_in_channels;

    Eigen::array<Eigen::Index, 2> im2col_weight_dims;
    im2col_weight_dims[0] = f_height * f_width * f_in_channels;
    im2col_weight_dims[1] = f_out_channels;

    Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
    bias_reshaped_dims[0] = 1;
    bias_reshaped_dims[1] = b_out_channels;

    Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
    weight_zp_bcast_dims[0] = f_height;
    weight_zp_bcast_dims[1] = f_width;
    weight_zp_bcast_dims[2] = f_in_channels;

    Eigen::array<Eigen::Index, 2> bias_bcast_dims;
    bias_bcast_dims[0] = out_batch * out_height * out_width;
    bias_bcast_dims[1] = 1;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // need to transpose to [N, H * W, KH, KW, C]
    ETensor5<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });

    // reshape input to [N * H * W, KH * KW * C]
    ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);

    // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
    ETensor2<WeightEigenType> im2col_weight =
        weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);

    // no need to apply bias_multiplier (* bias_scale and >> bias_shift) since tflite already scales it;
    // bias is reshaped from [C] to [1, C] and broadcast to [N * H * W, C]
    ETensor2<OutEigenType> bias_2d = (this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();

    // output matrix is [N * H * W, C]
    ETensor2<OutEigenType> contracted_result =
        (im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims)).template cast<OutEigenType>();

    // adding bias
    ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;

    // reshape back to [N, H, W, C]
    this->output->getTensor() = biased_output.reshape(col2im_output_dims);

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                                   TosaAttributeBase* attribute_,
                                                   uint64_t id_)
    : GraphNode(sgt_, Op_CONV3D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(5);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' checked separately since it doesn't make sense to make the required rank range from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv3d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpConv3d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpConv3d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int pad_d0     = this->attribute->pad()[0];
    int pad_d1     = this->attribute->pad()[1];
    int pad_top    = this->attribute->pad()[2];
    int pad_bottom = this->attribute->pad()[3];
    int pad_left   = this->attribute->pad()[4];
    int pad_right  = this->attribute->pad()[5];

    int stride_d = this->attribute->stride()[0];
    int stride_y = this->attribute->stride()[1];
    int stride_x = this->attribute->stride()[2];

    int dilation_d = this->attribute->dilation()[0];
    int dilation_y = this->attribute->dilation()[1];
    int dilation_x = this->attribute->dilation()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_d * f_depth <= tosa_level.MAX_KERNEL, "dilation_d * KD should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL, "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL, "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d0 <= tosa_level.MAX_KERNEL, "pad_d0 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_d1 <= tosa_level.MAX_KERNEL, "pad_d1 should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_d <= tosa_level.MAX_STRIDE, "stride_d should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_y, stride_x, dilation_d, dilation_y,
        dilation_x, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);

    Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_d0, pad_d1);
    pad[2] = std::make_pair(pad_top, pad_bottom);
    pad[3] = std::make_pair(pad_left, pad_right);
    pad[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(pad);

    // 1. initialize with bias
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    bcast[4] = 1;
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
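    // Each output element accumulates, starting from its bias value:
    //   acc += input_padded[ob, od*stride_d + fd*dilation_d,
    //                       oh*stride_y + fh*dilation_y,
    //                       ow*stride_x + fw*dilation_x, ic] * weight[oc, fd, fh, fw, ic]
    // summed over fd, fh, fw, and ic.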
    AccEigenType acc(0.0);
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        // Initialize accumulator with bias value
                        acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_y + fh * dilation_y;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_x + fw * dilation_x;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
                    }
                }
            }
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' checked separately since it doesn't make sense to make the required rank range from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
                             weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
    {
        msg = "OpDepthwiseConv2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    int pad_top    = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left   = this->attribute->pad()[2];
    int pad_right  = this->attribute->pad()[3];

    int stride_y   = this->attribute->stride()[0];
    int stride_x   = this->attribute->stride()[1];
    int dilation_y = this->attribute->dilation()[0];
    int dilation_x = this->attribute->dilation()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(dilation_y * f_height <= tosa_level.MAX_KERNEL, "dilation_y * KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(dilation_x * f_width <= tosa_level.MAX_KERNEL, "dilation_x * KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_y, stride_x, dilation_y, dilation_x, pad_top,
               pad_bottom, pad_left, pad_right);

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val  = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(pad);

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/pad
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_y, stride_x, dilation_y, dilation_x, Eigen::PADDING_VALID);

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
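    // Each input channel ic produces f_multiplier independent outputs; the output
    // channel index is ic * f_multiplier + cm. E.g. (assumed values) in_channels = 2
    // and f_multiplier = 3 give 6 output channels, with channels 0..2 driven by
    // input channel 0.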
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                // Perform multiplication in AccEigenType then cast to OutEigenType
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
                                                   (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

James Wardd34b3fc2023-01-18 14:51:25 +00001217template <DType InDtype, DType WeightDtype, DType OutDtype>
1218OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
Kevin Chengacb550f2021-06-29 15:32:19 -07001219 TosaAttributeBase* attribute_,
Eric Kunzee5e26762020-10-13 16:11:07 -07001220 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001221 : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001222{
1223 setRequiredOperands(3, 1);
1224 setRequiredRank(2);
1225
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001226 INIT_ATTRIBUTE(FullyConnected);
Eric Kunzee5e26762020-10-13 16:11:07 -07001227}
1228
James Wardd34b3fc2023-01-18 14:51:25 +00001229template <DType InDtype, DType WeightDtype, DType OutDtype>
1230OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
Eric Kunzee5e26762020-10-13 16:11:07 -07001231{
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001232 if (attribute)
1233 delete attribute;
Eric Kunzee5e26762020-10-13 16:11:07 -07001234}
1235
James Wardd34b3fc2023-01-18 14:51:25 +00001236template <DType InDtype, DType WeightDtype, DType OutDtype>
1237int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001238{
1239 if (validateRequiredOperands())
1240 return 1;
1241
1242 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1243 {
1244 return 1;
1245 }
1246
1247 input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
1248 weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
1249 bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
1250
1251 if (input->getShape()[1] != weight->getShape()[1])
1252 {
1253 printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
1254 return 1;
1255 }
1256
1257 if (weight->getShape()[0] != bias->getShape()[0])
1258 {
1259 printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
1260 return 1;
1261 }
1262
James Wardd34b3fc2023-01-18 14:51:25 +00001263 ERROR_IF(outputs[0]->getDtype() != OutDtype,
James Ward8b390432022-08-12 20:48:56 +01001264 "OpFullyConnected: Output data type not supported for this configuration of operator");
Kevin Chengcc61be32021-10-14 17:09:57 -07001265
James Ward8b390432022-08-12 20:48:56 +01001266 output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -07001267
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001268 ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
1269 ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
Kevin Chengcc61be32021-10-14 17:09:57 -07001270
Eric Kunzee5e26762020-10-13 16:11:07 -07001271 return 0;
1272}
1273
James Wardd34b3fc2023-01-18 14:51:25 +00001274template <DType InDtype, DType WeightDtype, DType OutDtype>
1275int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -07001276{
1277 typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
1278 Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
1279
1280 Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };
1281
1282 Eigen::array<Eigen::Index, 2> bias_reshape;
1283 bias_reshape[0] = 1;
1284 bias_reshape[1] = this->bias->getShape()[0];
1285
1286 Eigen::array<Eigen::Index, 2> bias_bcast;
1287 bias_bcast[0] = this->input->getShape()[0];
1288 bias_bcast[1] = 1;
1289
1290 TIn input_val = this->input->getTensor();
1291 TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
Eric Kunzef7337832022-06-17 08:19:12 -07001292 if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
Eric Kunzee5e26762020-10-13 16:11:07 -07001293 {
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001294 input_val = input_val - (InEigenType)attribute->input_zp();
1295 weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
Eric Kunzee5e26762020-10-13 16:11:07 -07001296 }
1297
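    // output[n, oc] = sum_ic(input[n, ic] * weight[oc, ic]) + bias[oc]; the shuffle
    // above transposes the weight from [OC, IC] to [IC, OC] so the contraction is a
    // plain matrix multiply, accumulated in AccEigenType before casting to OutEigenType.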
    this->output->getTensor() =
        input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
        this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }
    return GraphNode::eval();
}

template <DType Dtype, DType OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
                                    TosaAttributeBase* attribute_,
                                    uint64_t id_)
    : GraphNode(sgt_, Op_MATMUL, id_)
{
    setRequiredOperands(2, 1);
    setRequiredRank(3);

    INIT_ATTRIBUTE(MatMul);
}

template <DType Dtype, DType OutDtype>
OpMatMul<Dtype, OutDtype>::~OpMatMul()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype, DType OutDtype>
int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpMatMul: Output data type not supported for this configuration of operator");

    a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(a && b && output);

    // a: [N, H, C]
    // b: [N, C, W]
    // c: [N, H, W]

    // Check N
    if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
    {
        printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
        return 1;
    }
    N = a->getShape()[0];

    // Check C
    if (a->getShape()[2] != b->getShape()[1])
    {
        printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
        return 1;
    }
    C = a->getShape()[2];

    // Check H
    if (a->getShape()[1] != output->getShape()[1])
    {
        printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
        return 1;
    }
    H = a->getShape()[1];

1378 if (b->getShape()[2] != output->getShape()[2])
1379 {
1380 printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]");
1381 return 1;
1382 }
1383 W = b->getShape()[2];
Eric Kunzee5e26762020-10-13 16:11:07 -07001384

    ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
    ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");

    return 0;
}

template <DType Dtype, DType OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    if (Dtype == DType_INT8)
    {
        a_val = a_val - (InEigenType)attribute->a_zp();
        b_val = b_val - (InEigenType)attribute->b_zp();
    }

    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

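    // Eigen's contract() works on rank-2 tensors, so each batch is sliced out,
    // reshaped to 2-D, multiplied, and the rank-3 per-batch results are then
    // concatenated back together along the batch axis.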
    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
}

template <DType Dtype>
OpMaxPool2d<Dtype>::~OpMaxPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpMaxPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    int in_batch = this->in->getShape()[0];
    int in_height = this->in->getShape()[1];
    int in_width = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch = this->out->getShape()[0];
    int out_height = this->out->getShape()[1];
    int out_width = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int pad_top = this->attribute->pad()[0];
    int pad_bottom = this->attribute->pad()[1];
    int pad_left = this->attribute->pad()[2];
    int pad_right = this->attribute->pad()[3];

    int kernel_y = this->attribute->kernel()[0];
    int kernel_x = this->attribute->kernel()[1];
    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(kernel_y <= tosa_level.MAX_KERNEL, "kernel_y should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(kernel_x <= tosa_level.MAX_KERNEL, "kernel_x should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(pad_top <= tosa_level.MAX_KERNEL, "pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_bottom <= tosa_level.MAX_KERNEL, "pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y,
               kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_y * kernel_x;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
    pad[0] = std::make_pair(0, 0);
    pad[1] = std::make_pair(pad_top, pad_bottom);
    pad[2] = std::make_pair(pad_left, pad_right);
    pad[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(pad, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_y, kernel_x, stride_y, stride_x, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KH x KW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

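    // tensor_argmax(i) holds the winning row within the KH * KW patch for output
    // element i, so indexing the patch matrix with it below yields the max value.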
    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // index input_patches with argmax array should give the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}

template <DType Dtype>
OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
                        TosaAttributeBase* attribute_,
                        uint64_t id_)
    : GraphNode(sgt_, Op_FFT2D, id_)
{
    setRequiredOperands(2, 2);
    setRequiredRank(3);

    INIT_ATTRIBUTE(FFT);
}

template <DType Dtype>
OpFFT2d<Dtype>::~OpFFT2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype>
int OpFFT2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) ||
        validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) ||
        inputs[0]->matchType(*inputs[1]))
    {
        printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
        return 1;
    }

    in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);

    ASSERT_MEM(in_real && in_imag && out_real && out_imag);

    std::string msg;
    if (check_fft_shape(in_real->getShape(), in_imag->getShape(),
                        out_real->getShape(), out_imag->getShape(), msg))
    {
        msg = "OpFFT2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpFFT2d<Dtype>::eval()
{
    int in_real_batch = this->in_real->getShape()[0];
    int in_real_height = this->in_real->getShape()[1];
    int in_real_width = this->in_real->getShape()[2];

    int in_imag_batch = this->in_imag->getShape()[0];
    int in_imag_height = this->in_imag->getShape()[1];
    int in_imag_width = this->in_imag->getShape()[2];

    int out_real_batch = this->out_real->getShape()[0];
    int out_real_height = this->out_real->getShape()[1];
    int out_real_width = this->out_real->getShape()[2];

    int out_imag_batch = this->out_imag->getShape()[0];
    int out_imag_height = this->out_imag->getShape()[1];
    int out_imag_width = this->out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_real_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_real_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
               in_real_batch, in_real_height, in_real_width,
               in_imag_batch, in_imag_height, in_imag_width,
               out_real_batch, out_real_height, out_real_width,
               out_imag_batch, out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;

    if (attribute->inverse())
    {
        sign_val = -1.0;
    }

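    // Direct evaluation of the 2-D DFT:
    //   out[oy, ox] = sum_{iy, ix} in[iy, ix] * exp(-j * sign_val * 2 * pi * (oy * iy / H + ox * ix / W))
    // expanded into real and imaginary parts below; sign_val = -1 flips the
    // exponent's sign for the inverse transform (no 1/(H*W) scaling is applied here).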
    for (int n = 0; n < in_real_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_real_height; iy++)
                {
                    for (int ix = 0; ix < in_real_width; ix++)
                    {
                        OutEigenType val_real = this->in_real->getTensor()(n, iy, ix);
                        OutEigenType val_imag = this->in_imag->getTensor()(n, iy, ix);
1697 // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
1698 a = sign_val * 2 * M_PI * ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
1699 sum_real += val_real * cos(a) + val_imag * sin(a);
1700 sum_imag += -val_real * sin(a) + val_imag * cos(a);
1701 }
1702 }
1703 this->out_real->getTensor()(n, oy, ox) = sum_real;
1704 this->out_imag->getTensor()(n, oy, ox) = sum_imag;
1705 }
1706 }
1707 }
1708
1709 return GraphNode::eval();
1710}
1711
template <DType Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
                          TosaAttributeBase* attribute_,
                          uint64_t id_)
    : GraphNode(sgt_, Op_RFFT2D, id_)
{
    setRequiredOperands(1, 2);
    setRequiredRank(3);
}

template <DType Dtype>
OpRFFT2d<Dtype>::~OpRFFT2d() {}

template <DType Dtype>
int OpRFFT2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) ||
        validateRequiredRank(outputs[1]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
    {
        printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
        return 1;
    }

    in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);

    ASSERT_MEM(in && out_real && out_imag);

    std::string msg;
    if (check_fft_shape(in->getShape(), {},
                        out_real->getShape(), out_imag->getShape(), msg))
    {
        msg = "OpRFFT2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpRFFT2d<Dtype>::eval()
{
    int32_t in_batch = in->getShape()[0];
    int32_t in_height = in->getShape()[1];
    int32_t in_width = in->getShape()[2];

    int32_t out_real_batch = out_real->getShape()[0];
    int32_t out_real_height = out_real->getShape()[1];
    int32_t out_real_width = out_real->getShape()[2];

    int32_t out_imag_batch = out_imag->getShape()[0];
    int32_t out_imag_height = out_imag->getShape()[1];
    int32_t out_imag_width = out_imag->getShape()[2];

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(in_height <= tosa_level.MAX_KERNEL, "H should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(in_width <= tosa_level.MAX_KERNEL, "W should be smaller than or equal to MAX_KERNEL");

    DEBUG_INFO(OP,
               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
               "output_imag.shape=[%d,%d,%d]",
               in_batch, in_height, in_width,
               out_real_batch, out_real_height, out_real_width,
               out_imag_batch, out_imag_height, out_imag_width);

    OutEigenType sum_real, sum_imag, a;

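    // Real-input DFT: the spectrum of a real signal is conjugate-symmetric, so
    // only the non-redundant left half (out_real_width columns, W/2 + 1 per the
    // shape check) is computed:
    //   out[oy, ox] = sum_{iy, ix} in[iy, ix] * exp(-j * 2 * pi * (oy * iy / H + ox * ix / W))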
    for (int n = 0; n < in_batch; n++)
    {
        for (int oy = 0; oy < out_real_height; oy++)
        {
            for (int ox = 0; ox < out_real_width; ox++)
            {
                sum_real = 0.0;
                sum_imag = 0.0;
                for (int iy = 0; iy < in_height; iy++)
                {
                    for (int ix = 0; ix < in_width; ix++)
                    {
                        // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
                        sum_real += this->in->getTensor()(n, iy, ix) * cos(a);
                        sum_imag += -this->in->getTensor()(n, iy, ix) * sin(a);
                    }
                }
                this->out_real->getTensor()(n, oy, ox) = sum_real;
                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
            }
        }
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                                     TosaAttributeBase* attribute_,
                                                                     uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(TransposeConv);
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
{
    if (attribute)
        delete attribute;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != OutDtype,
             "OpTransposeConv2d: Output data type not supported for this configuration of operator");

    input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->out_pad().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute out_pad");
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->output_shape().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
        return 1;
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
            return 1;
        }
    }

    for (int d = 0; d < 4; d++)
    {
        if (attribute->output_shape()[d] != this->output->getShape()[d])
        {
            printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
            return 1;
        }
    }

    int32_t IH = input->getShape()[1];
    int32_t IW = input->getShape()[2];
    int32_t OH = output->getShape()[1];
    int32_t OW = output->getShape()[2];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_h = weight->getShape()[1];
    int32_t kernel_w = weight->getShape()[2];

    int32_t out_pad_top = attribute->out_pad()[0];
    int32_t out_pad_bottom = attribute->out_pad()[1];
    int32_t out_pad_left = attribute->out_pad()[2];
    int32_t out_pad_right = attribute->out_pad()[3];

    for (size_t i = 0; i < attribute->out_pad().size(); i++)
    {
        ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]), "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
    }

    int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
    int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;
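    // Expected output size: the IH input rows land stride_y apart, each contributes
    // a kernel_h-tall window, and out_pad_top/bottom extend (or, when negative,
    // trim) the result; width follows the same relation.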

    if ((OH != H) || (OW != W))
    {
        std::string msg = "OpTransposeConv2d: Mismatch between output shape provided and expected output shape (" +
                          std::to_string(H) + "," +
                          std::to_string(W) + ")";
        printNodeValidationError(msg.c_str());
        return 1;
    }

    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");

    return 0;
}

template <DType InDtype, DType WeightDtype, DType OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
    int in_batch = this->input->getShape()[0];
    int in_height = this->input->getShape()[1];
    int in_width = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height = this->weight->getShape()[1];
    int f_width = this->weight->getShape()[2];
    int f_in_channels = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch = this->output->getShape()[0];
    int out_height = this->output->getShape()[1];
    int out_width = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    int out_pad_top = this->attribute->out_pad()[0];
    int out_pad_bottom = this->attribute->out_pad()[1];
    int out_pad_left = this->attribute->out_pad()[2];
    int out_pad_right = this->attribute->out_pad()[3];

    int stride_y = this->attribute->stride()[0];
    int stride_x = this->attribute->stride()[1];

    ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
             f_out_channels, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(f_height <= tosa_level.MAX_KERNEL, "KH should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(f_width <= tosa_level.MAX_KERNEL, "KW should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_top <= tosa_level.MAX_KERNEL, "out_pad_top should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_bottom <= tosa_level.MAX_KERNEL, "out_pad_bottom should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_left <= tosa_level.MAX_KERNEL, "out_pad_left should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(out_pad_right <= tosa_level.MAX_KERNEL, "out_pad_right should be smaller than or equal to MAX_KERNEL");
    LEVEL_CHECK(stride_y <= tosa_level.MAX_STRIDE, "stride_y should be smaller than or equal to MAX_STRIDE");
    LEVEL_CHECK(stride_x <= tosa_level.MAX_STRIDE, "stride_x should be smaller than or equal to MAX_STRIDE");

    DEBUG_INFO(OP,
               "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
               out_batch, out_height, out_width, out_channels, stride_y, stride_x, out_pad_top,
               out_pad_bottom, out_pad_left, out_pad_right);

    TIn input_val = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
    {
        input_val = input_val - (InEigenType)attribute->input_zp();
        weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
    }

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    int out_x_origin, out_y_origin;
    int out_x, out_y;

    // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
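    // Scatter formulation: instead of gathering over input positions for each
    // output pixel, every input pixel (ih, iw) is multiplied by the full KH x KW
    // kernel and accumulated into the output window anchored at
    // (ih * stride_y + out_pad_top, iw * stride_x + out_pad_left); contributions
    // that fall outside the output bounds are skipped.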
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int ih = 0; ih < in_height; ih++)
        {
            for (int iw = 0; iw < in_width; iw++)
            {
                out_x_origin = iw * stride_x + out_pad_left;
                out_y_origin = ih * stride_y + out_pad_top;
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int fh = 0; fh < f_height; fh++)
                    {
                        for (int fw = 0; fw < f_width; fw++)
                        {
                            out_x = out_x_origin + fw;
                            out_y = out_y_origin + fh;
                            for (int oc = 0; oc < out_channels; oc++)
                            {
                                if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
                                {
                                    this->output->getTensor()(ob, out_y, out_x, oc) +=
                                        (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
                                                       (AccEigenType)weight_val(oc, fh, fw, ic));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    if (OutDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
    }

    return GraphNode::eval();
}

// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);

DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);

// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);

DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);

DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);

DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);

DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);

DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);

DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);

DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);

DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);