Remove accumulator attributes from all operators except AVG_POOL2D

Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: If67f503a1848967bc1671646c3011d055b622c52
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index dff9e08..4663c47 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -541,8 +541,8 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
     : GraphNode(sgt_, Op_CONV2D, id_)
@@ -553,15 +553,15 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv2d<InDtype, WeightDtype, AccDtype>::~OpConv2d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -577,7 +577,7 @@
         printNodeValidationError("OpConv2d: bias tensor must be rank 1");
     }
 
-    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+    ERROR_IF(outputs[0]->getDtype() != OutDtype,
                 "OpConv2d: Output data type not supported for this configuration of operator");
 
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -597,8 +597,8 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
@@ -634,14 +634,12 @@
     int dilation_h     = this->attribute->dilation()[0];
     int dilation_w     = this->attribute->dilation()[1];
 
-    tosa::DType accum_dtype       = (tosa::DType)this->attribute->accum_dtype();
-
     DEBUG_INFO(OP,
                "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
-               "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
+               "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
                in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
                out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
-               pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
+               pad_bottom, pad_left, pad_right);
 
     // GEMM-conv2d, left matrix is input, right matrix is weight
     Eigen::array<Eigen::Index, 2> im2col_input_dims;
@@ -717,7 +715,7 @@
     // reshape back to [N, H, W, C]
     this->output->getTensor() = biased_output.reshape(col2im_output_dims);
 
-    if (AccDtype == DType_INT48)
+    if (OutDtype == DType_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -726,8 +724,8 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
     : GraphNode(sgt_, Op_CONV3D, id_)
@@ -738,15 +736,15 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv3d<InDtype, WeightDtype, AccDtype>::~OpConv3d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -762,7 +760,7 @@
         printNodeValidationError("OpConv3d: bias tensor must be rank 1");
     }
 
-    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+    ERROR_IF(outputs[0]->getDtype() != OutDtype,
                 "OpConv3d: Output data type not supported for this configuration of operator");
 
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -782,8 +780,8 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_depth    = this->input->getShape()[1];
@@ -827,15 +825,13 @@
     int dilation_h     = this->attribute->dilation()[1];
     int dilation_w     = this->attribute->dilation()[2];
 
-    tosa::DType accum_dtype       = (tosa::DType)this->attribute->accum_dtype();
-
     DEBUG_INFO(
         OP,
         "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
-        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s",
+        "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
         in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
         out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
-        dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
+        dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);
 
     Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
     pad[0] = std::make_pair(0, 0);
@@ -907,7 +903,7 @@
         }
     }
 
-    if (AccDtype == DType_INT48)
+    if (OutDtype == DType_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -916,8 +912,8 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                            TosaAttributeBase* attribute_,
                                                            uint64_t id_)
     : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
@@ -928,15 +924,15 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::~OpDepthwiseConv2d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -952,7 +948,7 @@
         printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
     }
 
-    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+    ERROR_IF(outputs[0]->getDtype() != OutDtype,
                 "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
 
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -972,8 +968,8 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
@@ -1010,14 +1006,12 @@
     int dilation_h     = this->attribute->dilation()[0];
     int dilation_w     = this->attribute->dilation()[1];
 
-    tosa::DType accum_dtype       = (tosa::DType)this->attribute->accum_dtype();
-
     DEBUG_INFO(OP,
                "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
-               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
+               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
                in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
                out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
-               pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
+               pad_bottom, pad_left, pad_right);
 
     Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
     pad[0] = std::make_pair(0, 0);
@@ -1083,7 +1077,7 @@
         }
     }
 
-    if (AccDtype == DType_INT48)
+    if (OutDtype == DType_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1092,8 +1086,8 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                          TosaAttributeBase* attribute_,
                                                          uint64_t id_)
     : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
@@ -1104,15 +1098,15 @@
     INIT_ATTRIBUTE(FullyConnected);
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpFullyConnected<InDtype, WeightDtype, AccDtype>::~OpFullyConnected()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1138,7 +1132,7 @@
         return 1;
     }
 
-    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+    ERROR_IF(outputs[0]->getDtype() != OutDtype,
                 "OpFullyConnected: Output data type not supported for this configuration of operator");
 
     output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -1149,8 +1143,8 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
     Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1177,7 +1171,7 @@
         input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
             this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
 
-    if (AccDtype == DType_INT48)
+    if (OutDtype == DType_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1185,8 +1179,8 @@
     return GraphNode::eval();
 }
 
-template <DType Dtype, DType AccDtype>
-OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_,
+template <DType Dtype, DType OutDtype>
+OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
     : GraphNode(sgt_, Op_MATMUL, id_)
@@ -1197,15 +1191,15 @@
     INIT_ATTRIBUTE(MatMul);
 }
 
-template <DType Dtype, DType AccDtype>
-OpMatMul<Dtype, AccDtype>::~OpMatMul()
+template <DType Dtype, DType OutDtype>
+OpMatMul<Dtype, OutDtype>::~OpMatMul()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype, DType AccDtype>
-int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
+template <DType Dtype, DType OutDtype>
+int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1215,7 +1209,7 @@
         return 1;
     }
 
-    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+    ERROR_IF(outputs[0]->getDtype() != OutDtype,
                 "OpMatMul: Output data type not supported for this configuration of operator");
 
     a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -1266,8 +1260,8 @@
     return 0;
 }
 
-template <DType Dtype, DType AccDtype>
-int OpMatMul<Dtype, AccDtype>::eval()
+template <DType Dtype, DType OutDtype>
+int OpMatMul<Dtype, OutDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
     Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1312,7 +1306,7 @@
         }
     }
 
-    if (AccDtype == DType_INT48)
+    if (OutDtype == DType_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1587,8 +1581,8 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                            TosaAttributeBase* attribute_,
                                                            uint64_t id_)
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
@@ -1599,15 +1593,15 @@
     INIT_ATTRIBUTE(TransposeConv);
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1617,7 +1611,7 @@
         return 1;
     }
 
-    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+    ERROR_IF(outputs[0]->getDtype() != OutDtype,
                 "OpTransposeConv2d: Output data type not supported for this configuration of operator");
 
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -1701,8 +1695,8 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
@@ -1729,8 +1723,6 @@
     int stride_h = this->attribute->stride()[0];
     int stride_w = this->attribute->stride()[1];
 
-    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
-
     ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
     ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
              in_channels);
@@ -1741,10 +1733,10 @@
 
     DEBUG_INFO(OP,
                "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
-               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s",
+               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
                in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
                out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top,
-               out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]);
+               out_pad_bottom, out_pad_left, out_pad_right);
 
     TIn input_val      = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
@@ -1803,7 +1795,7 @@
         }
     }
 
-    if (AccDtype == DType_INT48)
+    if (OutDtype == DType_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1819,52 +1811,52 @@
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
 
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
 
-                                          // [in_t, weight_t, acc_t]
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);
+                                // [in_t, weight_t, out_t]
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
 
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
 
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
 
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
 
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
 
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
@@ -1874,10 +1866,10 @@
 
 DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
 
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);