[ref model] Add acc_type to Conv Ops

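This patch implements the changes required by the new acc_type field in
ConvAttribute and TransposeConvAttribute. OpConv2d, OpConv3d,
OpDepthwiseConv2d and OpTransposeConv2d gain an explicit AccDtype template
parameter, which now determines the accumulator type instead of the output
type, and the op factory selects the instantiation by matching the
accumulator type read from the attribute in addition to the input, weight
and output types.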

Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index e10f132..c0dceda 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -41,6 +41,10 @@
 #define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3)                                                         \
     template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>;
 
+#define DEF_INSTANTIATE_FOUR_TYPE(OP, DTYPE1, DTYPE2, DTYPE3, DTYPE4)                                                  \
+    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3,           \
+                                     TOSA_REF_TYPE_##DTYPE4>;
+
 #define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE)                                                 \
     template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, OP_TYPE>;
 
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0f0013c..74315d7 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -70,41 +70,43 @@
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
             break;
         case Op_CONV2D:
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
         case Op_CONV3D:
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
         case Op_DEPTHWISE_CONV2D:
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
         case Op_FFT2D:
             DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -148,16 +150,17 @@
             DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
             break;
         case Op_TRANSPOSE_CONV2D:
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
 
         // activation_funcs
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 1d20066..f1d1680 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -94,6 +94,14 @@
         return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id);     \
     }
 
+#define DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OP, ATTR_NAME, IN_DTYPE, W_DTYPE, ACC_DTYPE, OUT_DTYPE)                       \
+    if (inputDTYPE == TOSA_REF_TYPE_##IN_DTYPE && weightDTYPE == TOSA_REF_TYPE_##W_DTYPE &&                            \
+        outputDTYPE == TOSA_REF_TYPE_##OUT_DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACC_DTYPE)      \
+    {                                                                                                                  \
+        return new OP<TOSA_REF_TYPE_##IN_DTYPE, TOSA_REF_TYPE_##W_DTYPE, TOSA_REF_TYPE_##ACC_DTYPE,                    \
+                      TOSA_REF_TYPE_##OUT_DTYPE>(sgt, attribute, id);                                                  \
+    }
+
 // Statement-expression to evaluate accumulate attribute in-place
 #define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME)                                                                           \
     ({                                                                                                                 \
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 7bd249b..afd20e9 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -586,8 +586,10 @@
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
+                                                             TosaAttributeBase* attribute_,
+                                                             uint64_t id_)
     : GraphNode(sgt_, Op_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -596,15 +598,15 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -640,8 +642,8 @@
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
@@ -793,8 +795,10 @@
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
+                                                             TosaAttributeBase* attribute_,
+                                                             uint64_t id_)
     : GraphNode(sgt_, Op_CONV3D, id_)
 {
     setRequiredOperands(3, 1);
@@ -803,15 +807,15 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv3d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -847,8 +851,8 @@
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_depth    = this->input->getShape()[1];
@@ -1008,10 +1012,10 @@
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
-                                                                     TosaAttributeBase* attribute_,
-                                                                     uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+                                                                               TosaAttributeBase* attribute_,
+                                                                               uint64_t id_)
     : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1020,15 +1024,15 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpDepthwiseConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1064,8 +1068,8 @@
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
@@ -1903,10 +1907,10 @@
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
-                                                                     TosaAttributeBase* attribute_,
-                                                                     uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+                                                                               TosaAttributeBase* attribute_,
+                                                                               uint64_t id_)
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1915,15 +1919,15 @@
     INIT_ATTRIBUTE(TransposeConv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -2017,8 +2021,8 @@
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
@@ -2168,39 +2172,39 @@
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16);
 
-// [in_t, weight_t, out_t]
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
+// [in_t, weight_t, acc_t, out_t]
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
 
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16, FP16);
 
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
 
 DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
 DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
@@ -2238,13 +2242,14 @@
 DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
 DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
 
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+// [in_t, weight_t, acc_t, out_t]
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index e2bb811..2e65548 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -75,7 +75,7 @@
         int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpConv2d : public GraphNode
 {
 public:
@@ -87,7 +87,7 @@
 
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
@@ -105,7 +105,7 @@
     tosa::TosaConvAttribute* attribute;
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpConv3d : public GraphNode
 {
 public:
@@ -117,7 +117,7 @@
 
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 5>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 5>;
@@ -135,7 +135,7 @@
     tosa::TosaConvAttribute* attribute;
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpDepthwiseConv2d : public GraphNode
 {
 public:
@@ -147,7 +147,7 @@
 
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
@@ -294,7 +294,7 @@
     tosa::TosaRFFTAttribute* attribute;
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpTransposeConv2d : public GraphNode
 {
 public:
@@ -306,7 +306,7 @@
 
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;