// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensor_ops.h"
#include "quant_util.h"
#include "template_types.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
23
Kevin Cheng7eb93d72021-10-09 01:26:08 +000024int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
25 std::vector<int32_t> input_shape,
26 std::vector<int32_t> output_shape,
27 std::string& msg)
28{
29 if (attribute->padding().size() != 4)
30 {
31 msg = "illegal size for attribute padding";
32 return 1;
33 }
34
35 if (attribute->kernel().size() != 2)
36 {
37 msg = "illegal size for attribute kernel";
38 return 1;
39 }
40
41 if (attribute->stride().size() != 2)
42 {
43 msg = "illegal size for attribute stride";
44 return 1;
45 }
46
47 for (int32_t i : attribute->padding())
48 {
49 if (i < 0)
50 {
51 msg = "At least one pad is smaller than zero";
52 return 1;
53 }
54 }
55
    for (int32_t i : attribute->kernel())
    {
        if (i < 1)
        {
            msg = "At least one kernel dimension is smaller than one";
            return 1;
        }
    }

    for (int32_t i : attribute->stride())
    {
        if (i < 1)
        {
            msg = "At least one stride dimension is smaller than one";
            return 1;
        }
    }

    int32_t IH = input_shape[1];
    int32_t IW = input_shape[2];
    int32_t OH = output_shape[1];
    int32_t OW = output_shape[2];

    int32_t pad_top    = attribute->padding()[0];
    int32_t pad_bottom = attribute->padding()[1];
    int32_t pad_left   = attribute->padding()[2];
    int32_t pad_right  = attribute->padding()[3];

    int32_t stride_y = attribute->stride()[0];
    int32_t stride_x = attribute->stride()[1];
    int32_t kernel_y = attribute->kernel()[0];
    int32_t kernel_x = attribute->kernel()[1];

    if (pad_top >= kernel_y || pad_bottom >= kernel_y || pad_left >= kernel_x || pad_right >= kernel_x)
    {
        msg = "At least one pad is >= kernel dimension";
        return 1;
    }

    int32_t allowed_min_input_height = (OH * stride_y) - pad_top - pad_bottom - stride_y + kernel_y;
    int32_t allowed_min_input_width  = (OW * stride_x) - pad_left - pad_right - stride_x + kernel_x;

    int32_t d_height = IH - allowed_min_input_height;
    int32_t d_width  = IW - allowed_min_input_width;

    // A consistent output shape requires the input to lie within one stride of
    // the minimum size, i.e. 0 <= d < stride in each dimension.
    if (d_height < 0 || d_height >= stride_y || d_width < 0 || d_width >= stride_x)
    {
        msg = "Mismatch between output shape provided and expected output shape";
        return 1;
    }

    return 0;
}
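
// Worked example of the shape check above (illustrative numbers, not from the
// spec): IH = 7, kernel_y = 3, stride_y = 2, pad_top = pad_bottom = 0. Then
// OH = 3 gives allowed_min_input_height = 3 * 2 - 2 + 3 = 7 and d_height = 0,
// which is accepted, while OH = 2 gives d_height = 7 - 5 = 2 >= stride_y and
// is rejected.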

template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                TosaQuantInfoBase* qinfo_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_ARGMAX, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(1, 4);

    INIT_ATTRIBUTE(Axis);
}

template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
    if (attribute)
        delete attribute;
}

template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]))
    {
        return 1;
    }

    int32_t output_rank = inputs[0]->getRank() - 1;
    if (output_rank != outputs[0]->getRank())
    {
        printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
        return 1;
    }

    if (outputs[0]->getDtype() != DType_INT32)
    {
        printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
    {
        printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input))");
        return 1;
    }

    bool shape_check = true;
    for (int32_t i = 0; i < input->getRank(); i++)
    {
        if (i < attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i])
            {
                shape_check = false;
                break;
            }
        }
        else if (i > attribute->axis())
        {
            if (input->getShape()[i] != output->getShape()[i - 1])
            {
                shape_check = false;
                break;
            }
        }
        // No need to check i == axis
    }
    if (!shape_check)
    {
        printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
        return 1;
    }

    return 0;
}

template <int Rank, DType Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
    Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());

    this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });

    return GraphNode::eval();
}
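
// Shape illustration for the eval above (hypothetical sizes): for an input of
// shape [2, 5, 7] and axis = 1, Eigen's argmax(1) reduces that dimension and
// yields a [2, 7] index tensor, matching the rank(input) - 1 output validated
// in checkTensorAttributes().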

template <DType Dtype>
OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                TosaQuantInfoBase* qinfo_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
    INIT_QINFO(Unary);
}

template <DType Dtype>
OpAvgPool2d<Dtype>::~OpAvgPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype>
int OpAvgPool2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    if (Dtype != DType_INT8 && this->qinfo)
    {
        ERROR_IF(this->qinfo->input_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t");
        ERROR_IF(this->qinfo->output_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t");
    }

    std::string msg;
    if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpAvgPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride)
{
    ETensor1<int32_t> result(out_size);

    int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size;
    total_pad         = total_pad < 0 ? 0 : total_pad;

    int32_t pad_left  = total_pad >> 1;
    int32_t pad_right = total_pad - pad_left;

    result.setConstant(kernel_size);

    // Output indices at or left of 'left_index', and at or right of
    // 'out_size - 1 - right_index', have input windows that overlap the padding
    int32_t left_index  = pad_left / stride;
    int32_t right_index = pad_right / stride;

    // subtract the number of padded positions each such window covers
    while (left_index >= 0)
    {
        result(left_index) -= (pad_left - left_index * stride);
        left_index--;
    }

    while (right_index >= 0)
    {
        result(out_size - 1 - right_index) -= (pad_right - right_index * stride);
        right_index--;
    }

    return result;
}
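
// Worked example for calculate_div_map_1d (hypothetical sizes): in_size = 4,
// out_size = 2, kernel_size = 3, stride = 2 gives total_pad = 1, pad_left = 0,
// pad_right = 1. The second window covers one padded position, so the divisor
// map is [3, 2] rather than a constant kernel_size of 3.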

// assuming input and output tensors have the same scale, as in the TFLite
// reference, so no input/output rescaling is needed
template <DType Dtype>
int OpAvgPool2d<Dtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int padding_top    = this->attribute->padding()[0];
    int padding_bottom = this->attribute->padding()[1];
    int padding_left   = this->attribute->padding()[2];
    int padding_right  = this->attribute->padding()[3];
    int kernel_h       = this->attribute->kernel()[0];
    int kernel_w       = this->attribute->kernel()[1];
    int stride_h       = this->attribute->stride()[0];
    int stride_w       = this->attribute->stride()[1];

    DEBUG_INFO(OP,
               "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], padding=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
               kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_h * kernel_w;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
    padding[0] = std::make_pair(0, 0);
    padding[1] = std::make_pair(padding_top, padding_bottom);
    padding[2] = std::make_pair(padding_left, padding_right);
    padding[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_val = this->in->getTensor();
    if (this->qinfo)
    {
        input_val = input_val - (InEigenType)this->qinfo->input_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(padding);

    // assuming input and output have the same scale, so no input/output
    // rescaling is required here
    // TODO: check whether TOSA actually makes this assumption

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    ETensor2<InEigenType> input_extract_patches =
        input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);
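
    // Illustrative shapes (hypothetical): a [1, 4, 4, 1] input with a 2x2
    // kernel and stride 2 has 2x2 output positions, so im2col_input_dims is
    // [2 * 2, 1 * 2 * 2 * 1] = [4, 4]: one row per kernel tap, one column per
    // output element.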

    // 1D result with [N * H * W * C]
    ETensor1<AccEigenType> out_1d(this->out->getElementCount());
    out_1d.setZero();

    // sum pool
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        for (int32_t j = 0; j < kernel_h * kernel_w; j++)
        {
            out_1d(i) += (AccEigenType)input_extract_patches(j, i);
        }
    }

    // reshape result to [N, H, W, C] and divide with div_map
    ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);

    // calculate 1d height/width div_map (number of elements this pooling window covers)
    // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h);
    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w);
    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
    Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };

    ETensor4<int32_t> div_map =
        div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 })
            .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
            .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
            .broadcast(bcast);

    if (Dtype != DType_FLOAT)
    {
        try
        {
            this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
                int32_t multiplier, shift;
                TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);

                return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
            });
        }
        catch (std::string desc)
        {
            REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
        }
        // guard the dereference: qinfo may be null, mirroring the check on the input side
        if (this->qinfo)
        {
            this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp());
        }
        this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
        this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
    }
    else
    {
        this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
    }

    return GraphNode::eval();
}
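
// Note on the integer path above (informal sketch of the intent, not a spec
// statement): sum / div is not a true division; reciprocal_scale() picks a
// multiplier and shift so that apply_scale_32(sum, multiplier, shift)
// approximates sum / div in fixed point, roughly (sum * multiplier) >> shift
// with multiplier ~= 2^shift / div.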

template <DType InDtype, DType WeightDtype>
OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                         TosaAttributeBase* attribute_,
                                         TosaQuantInfoBase* qinfo_,
                                         uint64_t id_)
    : GraphNode(sgt_, Op_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
    INIT_QINFO(Conv);
}

template <DType InDtype, DType WeightDtype>
OpConv2d<InDtype, WeightDtype>::~OpConv2d()
{
    if (attribute)
        delete attribute;
    if (qinfo)
        delete qinfo;
}

template <DType InDtype, DType WeightDtype>
int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' is checked separately since it doesn't make sense to give it a
    // required rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv2d: bias tensor must be rank 1");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);

    if (attribute->padding().size() != 4)
    {
        printNodeValidationError("OpConv2d: illegal size for attribute padding");
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->dilation().size() != 2)
    {
        printNodeValidationError("OpConv2d: illegal size for attribute dilation");
        return 1;
    }

    if (this->qinfo)
    {
        if (InDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t");
        }
        if (WeightDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t");
        }
    }

    return 0;
}

template <DType InDtype, DType WeightDtype>
int OpConv2d<InDtype, WeightDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height       = this->weight->getShape()[1];
    int f_width        = this->weight->getShape()[2];
    int f_in_channels  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int padding_top    = this->attribute->padding()[0];
    int padding_bottom = this->attribute->padding()[1];
    int padding_left   = this->attribute->padding()[2];
    int padding_right  = this->attribute->padding()[3];
    int stride_h       = this->attribute->stride()[0];
    int stride_w       = this->attribute->stride()[1];
    int dilation_h     = this->attribute->dilation()[0];
    int dilation_w     = this->attribute->dilation()[1];

    DEBUG_INFO(OP,
               "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
               "stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
               out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
               padding_bottom, padding_left, padding_right);

    // GEMM-conv2d, left matrix is input, right matrix is weight
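    // Illustrative GEMM sizes (hypothetical): a [1, 8, 8, 3] input with a
    // [16, 3, 3, 3] weight, stride 1, dilation 1 and no padding gives a
    // [1, 6, 6, 16] output, so the contraction below is
    // [1*6*6, 3*3*3] x [3*3*3, 16] = [36, 27] x [27, 16].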
    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = out_batch * out_height * out_width;
    im2col_input_dims[1] = f_height * f_width * f_in_channels;

    Eigen::array<Eigen::Index, 2> im2col_weight_dims;
    im2col_weight_dims[0] = f_height * f_width * f_in_channels;
    im2col_weight_dims[1] = f_out_channels;

    Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
    bias_reshaped_dims[0] = 1;
    bias_reshaped_dims[1] = b_out_channels;

    Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
    weight_zp_bcast_dims[0] = f_height;
    weight_zp_bcast_dims[1] = f_width;
    weight_zp_bcast_dims[2] = f_in_channels;

    Eigen::array<Eigen::Index, 2> bias_bcast_dims;
    bias_bcast_dims[0] = out_batch * out_height * out_width;
    bias_bcast_dims[1] = 1;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };

    Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
    padding[0] = std::make_pair(0, 0);
    padding[1] = std::make_pair(padding_top, padding_bottom);
    padding[2] = std::make_pair(padding_left, padding_right);
    padding[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (this->qinfo)
    {
        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(padding);

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // need to transpose to [N, H * W, KH, KW, C]
    ETensor5<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID)
            .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });

    // reshape input to [N * H * W, KH * KW * C]
    ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);

    // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
    ETensor2<WeightEigenType> im2col_weight =
        weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);

    // no need to apply bias_multiplier (* bias_scale and >> bias_shift) since TFLite already scales the bias;
    // reshape it from [C] to [1, C] and broadcast to [N * H * W, C]
    ETensor2<AccEigenType> bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims);

    // output matrix is [N * H * W, C]
    ETensor2<AccEigenType> contracted_result =
        im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims);

    // adding bias
    ETensor2<AccEigenType> biased_output = contracted_result + bias_2d.template cast<AccEigenType>();

    // reshape back to [N, H, W, C]
    this->output->getTensor() = biased_output.reshape(col2im_output_dims);

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype>
OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                         TosaAttributeBase* attribute_,
                                         TosaQuantInfoBase* qinfo_,
                                         uint64_t id_)
    : GraphNode(sgt_, Op_CONV3D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(5);

    INIT_ATTRIBUTE(Conv);
    INIT_QINFO(Conv);
}

template <DType InDtype, DType WeightDtype>
OpConv3d<InDtype, WeightDtype>::~OpConv3d()
{
    if (attribute)
        delete attribute;
    if (qinfo)
        delete qinfo;
}

template <DType InDtype, DType WeightDtype>
int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' is checked separately since it doesn't make sense to give it a
    // required rank ranging from 1 to 5
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpConv3d: bias tensor must be rank 1");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpConv3d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);

    if (attribute->padding().size() != 6)
    {
        printNodeValidationError("OpConv3d: illegal size for attribute padding");
        return 1;
    }

    if (attribute->stride().size() != 3)
    {
        printNodeValidationError("OpConv3d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->dilation().size() != 3)
    {
        printNodeValidationError("OpConv3d: illegal size for attribute dilation");
        return 1;
    }

    if (this->qinfo)
    {
        if (InDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t");
        }
        if (WeightDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t");
        }
    }

    return 0;
}

template <DType InDtype, DType WeightDtype>
int OpConv3d<InDtype, WeightDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_depth    = this->input->getShape()[1];
    int in_height   = this->input->getShape()[2];
    int in_width    = this->input->getShape()[3];
    int in_channels = this->input->getShape()[4];

    int f_out_channels = this->weight->getShape()[0];
    int f_depth        = this->weight->getShape()[1];
    int f_height       = this->weight->getShape()[2];
    int f_width        = this->weight->getShape()[3];
    int f_in_channels  = this->weight->getShape()[4];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_depth    = this->output->getShape()[1];
    int out_height   = this->output->getShape()[2];
    int out_width    = this->output->getShape()[3];
    int out_channels = this->output->getShape()[4];

    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
             out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);

    int padding_d0     = this->attribute->padding()[0];
    int padding_d1     = this->attribute->padding()[1];
    int padding_top    = this->attribute->padding()[2];
    int padding_bottom = this->attribute->padding()[3];
    int padding_left   = this->attribute->padding()[4];
    int padding_right  = this->attribute->padding()[5];
    int stride_d       = this->attribute->stride()[0];
    int stride_h       = this->attribute->stride()[1];
    int stride_w       = this->attribute->stride()[2];
    int dilation_d     = this->attribute->dilation()[0];
    int dilation_h     = this->attribute->dilation()[1];
    int dilation_w     = this->attribute->dilation()[2];

    DEBUG_INFO(
        OP,
        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
        "stride=[%d,%d,%d], dilation=[%d,%d,%d], padding=[%d,%d,%d,%d,%d,%d]",
        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
        dilation_w, padding_d0, padding_d1, padding_top, padding_bottom, padding_left, padding_right);

    Eigen::array<std::pair<int32_t, int32_t>, 5> padding;
    padding[0] = std::make_pair(0, 0);
    padding[1] = std::make_pair(padding_d0, padding_d1);
    padding[2] = std::make_pair(padding_top, padding_bottom);
    padding[3] = std::make_pair(padding_left, padding_right);
    padding[4] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (this->qinfo)
    {
        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
    }

    ETensor5<InEigenType> input_padded = input_val.pad(padding);

    // 1. initialize with bias
    Eigen::array<Eigen::Index, 5> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[4] = b_out_channels;

    Eigen::array<Eigen::Index, 5> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_depth;
    bcast[2] = out_height;
    bcast[3] = out_width;
    bcast[4] = 1;
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct convolution
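    // Index arithmetic below (illustrative numbers): with stride_d = 2 and
    // dilation_d = 2, output depth od = 2 and filter tap fd = 1 reads input
    // d_idx = 2 * 2 + 1 * 2 = 6 from the padded input.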
    AccEigenType acc = 0;
    int d_idx, h_idx, w_idx;

    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int od = 0; od < out_depth; od++)
        {
            for (int oh = 0; oh < out_height; oh++)
            {
                for (int ow = 0; ow < out_width; ow++)
                {
                    for (int oc = 0; oc < out_channels; oc++)
                    {
                        acc = 0;
                        for (int fd = 0; fd < f_depth; fd++)
                        {
                            d_idx = od * stride_d + fd * dilation_d;
                            for (int fh = 0; fh < f_height; fh++)
                            {
                                h_idx = oh * stride_h + fh * dilation_h;
                                for (int fw = 0; fw < f_width; fw++)
                                {
                                    w_idx = ow * stride_w + fw * dilation_w;
                                    for (int ic = 0; ic < in_channels; ic++)
                                    {
                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
                                    }
                                }
                            }
                        }
                        // accumulate on top of the bias initialization rather than overwriting it
                        this->output->getTensor()(ob, od, oh, ow, oc) += acc;
                    }
                }
            }
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype>
OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                           TosaAttributeBase* attribute_,
                                                           TosaQuantInfoBase* qinfo_,
                                                           uint64_t id_)
    : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Conv);
    INIT_QINFO(Conv);
}

template <DType InDtype, DType WeightDtype>
OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
{
    if (attribute)
        delete attribute;
    if (qinfo)
        delete qinfo;
}

template <DType InDtype, DType WeightDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // 'bias' is checked separately since it doesn't make sense to give it a
    // required rank ranging from 1 to 4
    if (inputs[2]->getRank() != 1)
    {
        printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);

    if (attribute->padding().size() != 4)
    {
        printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding");
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->dilation().size() != 2)
    {
        printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation");
        return 1;
    }

    if (this->qinfo)
    {
        if (InDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
        }
        if (WeightDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
        }
    }

    return 0;
}

template <DType InDtype, DType WeightDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_height      = this->weight->getShape()[0];
    int f_width       = this->weight->getShape()[1];
    int f_in_channels = this->weight->getShape()[2];
    int f_multiplier  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
             in_channels * f_multiplier, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    int padding_top    = this->attribute->padding()[0];
    int padding_bottom = this->attribute->padding()[1];
    int padding_left   = this->attribute->padding()[2];
    int padding_right  = this->attribute->padding()[3];
    int stride_h       = this->attribute->stride()[0];
    int stride_w       = this->attribute->stride()[1];
    int dilation_h     = this->attribute->dilation()[0];
    int dilation_w     = this->attribute->dilation()[1];

    DEBUG_INFO(OP,
               "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
               out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
               padding_bottom, padding_left, padding_right);

    Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
    padding[0] = std::make_pair(0, 0);
    padding[1] = std::make_pair(padding_top, padding_bottom);
    padding[2] = std::make_pair(padding_left, padding_right);
    padding[3] = std::make_pair(0, 0);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (this->qinfo)
    {
        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
    }

    ETensor4<InEigenType> input_padded = input_val.pad(padding);

    // GEMM doesn't fit well with DepthwiseConv2d
    // 1. use extract_image_patches() to handle stride/dilation/padding
    // 2. perform direct convolution

    // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
    ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
        f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID);

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    // 2. direct depthwise convolution
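    // Output channel layout (illustrative): with in_channels = 2 and
    // f_multiplier = 3, input channel ic = 1 feeds output channels
    // 1 * 3 + 0 .. 1 * 3 + 2, so out_channels = in_channels * f_multiplier = 6.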
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int oh = 0; oh < out_height; oh++)
        {
            for (int ow = 0; ow < out_width; ow++)
            {
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int cm = 0; cm < f_multiplier; cm++)
                    {
                        for (int fh = 0; fh < f_height; fh++)
                        {
                            for (int fw = 0; fw < f_width; fw++)
                            {
                                this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
                                    ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
                                     (AccEigenType)weight_val(fh, fw, ic, cm));
                            }
                        }
                    }
                }
            }
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType InDtype, DType WeightDtype>
OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                         TosaAttributeBase* attribute_,
                                                         TosaQuantInfoBase* qinfo_,
                                                         uint64_t id_)
    : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
    setRequiredOperands(3, 1);
    setRequiredRank(2);

    INIT_QINFO(Conv);
}

template <DType InDtype, DType WeightDtype>
OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
{
    if (qinfo)
        delete qinfo;
}

template <DType InDtype, DType WeightDtype>
int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);

    if (input->getShape()[1] != weight->getShape()[1])
    {
        printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
        return 1;
    }

    if (weight->getShape()[0] != bias->getShape()[0])
    {
        printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpFullyConnected: Output data type not supported for this configuration of operator");

    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);

    if (this->qinfo)
    {
        if (InDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
        }
        if (WeightDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
        }
    }

    return 0;
}

template <DType InDtype, DType WeightDtype>
int OpFullyConnected<InDtype, WeightDtype>::eval()
{
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };

    Eigen::array<Eigen::Index, 2> bias_reshape;
    bias_reshape[0] = 1;
    bias_reshape[1] = this->bias->getShape()[0];

    Eigen::array<Eigen::Index, 2> bias_bcast;
    bias_bcast[0] = this->input->getShape()[0];
    bias_bcast[1] = 1;

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
    if (this->qinfo)
    {
        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
    }

    this->output->getTensor() =
        input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) +
        this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
    }
    return GraphNode::eval();
}
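
// Shape summary for the computation above (illustrative sizes): input [N, IC]
// is contracted with the shuffled weight [IC, OC], and the bias is broadcast
// from [1, OC] to [N, OC]; e.g. [4, 128] x [128, 10] + [1, 10] -> [4, 10].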

template <DType Dtype>
OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
                          TosaAttributeBase* attribute_,
                          TosaQuantInfoBase* qinfo_,
                          uint64_t id_)
    : GraphNode(sgt_, Op_MATMUL, id_)
{
    setRequiredOperands(2, 1);
    setRequiredRank(3);

    INIT_QINFO(MatMul);
}

template <DType Dtype>
OpMatMul<Dtype>::~OpMatMul()
{
    if (qinfo)
        delete qinfo;
}

template <DType Dtype>
int OpMatMul<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpMatMul: Output data type not supported for this configuration of operator");

    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);

    ASSERT_MEM(a && b && output);

    // a: [N, H, C]
    // b: [N, C, W]
    // c: [N, H, W]
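    // e.g. (illustrative sizes): a [2, 3, 4] batch-multiplied by b [2, 4, 5]
    // produces c [2, 3, 5]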

    // Check N
    if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
    {
        printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
        return 1;
    }
    N = a->getShape()[0];

    // Check C
    if (a->getShape()[2] != b->getShape()[1])
    {
        printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
        return 1;
    }
    C = a->getShape()[2];

    // Check H
    if (a->getShape()[1] != output->getShape()[1])
    {
        printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
        return 1;
    }
    H = a->getShape()[1];

    // Check W
    if (b->getShape()[2] != output->getShape()[2])
    {
        printNodeValidationError("OpMatMul operator b.shape[2] should match output.shape[2]");
        return 1;
    }
    W = b->getShape()[2];

    if (Dtype != DType_INT8 && this->qinfo)
    {
        ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
        ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
    }

    return 0;
}

template <DType Dtype>
int OpMatMul<Dtype>::eval()
{
    typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };

    TIn a_val = this->a->getTensor();
    TIn b_val = this->b->getTensor();
    if (this->qinfo)
    {
        a_val = a_val - (InEigenType)this->qinfo->a_zp();
        b_val = b_val - (InEigenType)this->qinfo->b_zp();
    }

    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });

    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });

    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });

    // Iterate N dimension.
    for (int i = 0; i < N; i++)
    {
        a_begin_array[0] = i;
        b_begin_array[0] = i;

        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
        TAccRank2 output_rank2_val =
            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
        TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape);
        if (i == 0)
        {
            this->output->getTensor() = output_rank3_val;
        }
        else
        {
            TAcc temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
            this->output->getTensor() = temp;
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
    }

    return GraphNode::eval();
}

template <DType Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
                                TosaAttributeBase* attribute_,
                                TosaQuantInfoBase* qinfo_,
                                uint64_t id_)
    : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
    setRequiredOperands(1, 1);
    setRequiredRank(4);

    INIT_ATTRIBUTE(Pool);
}

template <DType Dtype>
OpMaxPool2d<Dtype>::~OpMaxPool2d()
{
    if (attribute)
        delete attribute;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    if (inputs[0]->matchType(*outputs[0]))
    {
        printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    std::string msg;
    if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg))
    {
        msg = "OpMaxPool2d: " + msg;
        printNodeValidationError(msg.c_str());
        return 1;
    }

    return 0;
}

template <DType Dtype>
int OpMaxPool2d<Dtype>::eval()
{
    int in_batch    = this->in->getShape()[0];
    int in_height   = this->in->getShape()[1];
    int in_width    = this->in->getShape()[2];
    int in_channels = this->in->getShape()[3];

    int out_batch    = this->out->getShape()[0];
    int out_height   = this->out->getShape()[1];
    int out_width    = this->out->getShape()[2];
    int out_channels = this->out->getShape()[3];

    ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);

    int padding_top    = this->attribute->padding()[0];
    int padding_bottom = this->attribute->padding()[1];
    int padding_left   = this->attribute->padding()[2];
    int padding_right  = this->attribute->padding()[3];
    int kernel_h       = this->attribute->kernel()[0];
    int kernel_w       = this->attribute->kernel()[1];
    int stride_h       = this->attribute->stride()[0];
    int stride_w       = this->attribute->stride()[1];

    DEBUG_INFO(OP,
               "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
               "stride=[%d,%d], padding=[%d,%d,%d,%d]",
               in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
               kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);

    Eigen::array<Eigen::Index, 2> im2col_input_dims;
    im2col_input_dims[0] = kernel_h * kernel_w;
    im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;

    Eigen::array<Eigen::Index, 4> col2im_output_dims;
    col2im_output_dims[0] = out_batch;
    col2im_output_dims[1] = out_height;
    col2im_output_dims[2] = out_width;
    col2im_output_dims[3] = out_channels;

    Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
    padding[0] = std::make_pair(0, 0);
    padding[1] = std::make_pair(padding_top, padding_bottom);
    padding[2] = std::make_pair(padding_left, padding_right);
    padding[3] = std::make_pair(0, 0);

    ETensor4<InEigenType> input_padded = this->in->getTensor().pad(padding, std::numeric_limits<InEigenType>::lowest());

    // extract_image_patches() output [N, KH, KW, H * W, C]
    // transpose to [KH, KW, N, H * W, C]
    // reshape to [KH * KW, N * H * W * C]
    //
    // Set the padding value to be the most negative value that can be
    // represented by the datatype to ensure that any padding values will be equal
    // to or smaller than the actual maximum in the KH x KW patch.
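    // (For example, with INT8 the pad value is -128, so a window containing
    // any real data can never be won by padding.)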
    ETensor2<InEigenType> input_extract_patches =
        input_padded
            .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID,
                                   std::numeric_limits<InEigenType>::lowest())
            .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
            .reshape(im2col_input_dims);

    // Get the maximum of the KH x KW patches along axis 0
    Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);

    // 1D result with [N * H * W * C]
    ETensor1<OutEigenType> out_1d(this->out->getElementCount());

    // indexing input_patches with the argmax array gives the result
    for (size_t i = 0; i < this->out->getElementCount(); i++)
    {
        out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
    }

    // reshape result to [N, H, W, C]
    this->out->getTensor() = out_1d.reshape(col2im_output_dims);

    return GraphNode::eval();
}
1408
Kevin Chengcc61be32021-10-14 17:09:57 -07001409template <DType InDtype, DType WeightDtype>
1410OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
1411 TosaAttributeBase* attribute_,
1412 TosaQuantInfoBase* qinfo_,
1413 uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -07001414 : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -07001415{
1416 setRequiredOperands(3, 1);
1417 setRequiredRank(4);
1418
Kevin Cheng93a16282021-08-31 16:14:03 -07001419 INIT_ATTRIBUTE(TransposeConv);
Eric Kunzee5e26762020-10-13 16:11:07 -07001420 INIT_QINFO(Conv);
1421}
1422
Kevin Chengcc61be32021-10-14 17:09:57 -07001423template <DType InDtype, DType WeightDtype>
1424OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
Eric Kunzee5e26762020-10-13 16:11:07 -07001425{
1426 if (attribute)
1427 delete attribute;
1428 if (qinfo)
1429 delete qinfo;
1430}
1431
Kevin Chengcc61be32021-10-14 17:09:57 -07001432template <DType InDtype, DType WeightDtype>
1433int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -07001434{
1435 if (validateRequiredOperands())
1436 return 1;
1437
1438 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
1439 {
1440 return 1;
1441 }
1442
    ERROR_IF(outputs[0]->getDtype() != AccDtype,
             "OpTransposeConv2d: Output data type not supported for this configuration of operator");

    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);

    if (attribute->outpad().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad");
        return 1;
    }

    if (attribute->stride().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
        return 1;
    }

    if (attribute->dilation().size() != 2)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute dilation");
        return 1;
    }

    if (attribute->output_shape().size() != 4)
    {
        printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
        return 1;
    }

    for (int d = 0; d < 4; d++)
    {
        if (attribute->output_shape()[d] != this->output->getShape()[d])
        {
            printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
            return 1;
        }
    }

    if (this->qinfo)
    {
        if (InDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
        }
        if (WeightDtype != DType_INT8)
        {
            ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
        }
    }

    return 0;
}

template <DType InDtype, DType WeightDtype>
int OpTransposeConv2d<InDtype, WeightDtype>::eval()
{
    int in_batch    = this->input->getShape()[0];
    int in_height   = this->input->getShape()[1];
    int in_width    = this->input->getShape()[2];
    int in_channels = this->input->getShape()[3];

    int f_out_channels = this->weight->getShape()[0];
    int f_height       = this->weight->getShape()[1];
    int f_width        = this->weight->getShape()[2];
    int f_in_channels  = this->weight->getShape()[3];

    int b_out_channels = this->bias->getShape()[0];

    int out_batch    = this->output->getShape()[0];
    int out_height   = this->output->getShape()[1];
    int out_width    = this->output->getShape()[2];
    int out_channels = this->output->getShape()[3];

    int padding_top  = this->attribute->outpad()[0];
    int padding_left = this->attribute->outpad()[1];
    int stride_h     = this->attribute->stride()[0];
    int stride_w     = this->attribute->stride()[1];
    int dilation_h   = this->attribute->dilation()[0];
    int dilation_w   = this->attribute->dilation()[1];

    ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
    ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
             in_channels);
    ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
             f_out_channels, out_channels);
    ERROR_IF(b_out_channels != out_channels, "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels,
             out_channels);

    DEBUG_INFO(OP,
               "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d]",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
               out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
               padding_left);

    TIn input_val      = this->input->getTensor();
    TWeight weight_val = this->weight->getTensor();
    if (this->qinfo)
    {
        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
    }

    Eigen::array<Eigen::Index, 4> reshape_dim;
    reshape_dim.fill(1);
    reshape_dim[3] = b_out_channels;

    Eigen::array<Eigen::Index, 4> bcast;
    bcast[0] = out_batch;
    bcast[1] = out_height;
    bcast[2] = out_width;
    bcast[3] = 1;

    // initialize with bias
    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);

    int out_x_origin, out_y_origin;
    int out_x, out_y;

    // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
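    // Scatter view of the loop below (illustrative numbers): with stride 2,
    // dilation 1 and padding 0, input pixel (ih, iw) = (1, 1) has
    // (out_y_origin, out_x_origin) = (2, 2) and contributes to output
    // positions (2 + fh, 2 + fw) for every in-bounds filter tap.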
    for (int ob = 0; ob < out_batch; ob++)
    {
        for (int ih = 0; ih < in_height; ih++)
        {
            for (int iw = 0; iw < in_width; iw++)
            {
                out_x_origin = iw * stride_w - padding_left;
                out_y_origin = ih * stride_h - padding_top;
                for (int ic = 0; ic < in_channels; ic++)
                {
                    for (int fh = 0; fh < f_height; fh++)
                    {
                        for (int fw = 0; fw < f_width; fw++)
                        {
                            out_x = out_x_origin + fw * dilation_w;
                            out_y = out_y_origin + fh * dilation_h;
                            for (int oc = 0; oc < out_channels; oc++)
                            {
                                if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
                                {
                                    this->output->getTensor()(ob, out_y, out_x, oc) +=
                                        ((AccEigenType)input_val(ob, ih, iw, ic) *
                                         (AccEigenType)weight_val(oc, fh, fw, ic));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    if (AccDtype == DType_INT48)
    {
        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
    }

    return GraphNode::eval();
}

// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);

DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT8)
DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)

DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT4);
DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);

DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);
DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);

DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);

DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT4);
DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);

DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);

DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);

DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT4);
DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);