[reference model] Add precise mode

This adds a --precise_mode=1 option to tosa_reference_model, which
causes the reference model to convert all floating-point tensors to
FP64 and to compute all operators at that precision.
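
Conceptually, the widening happens once, at dtype-conversion time,
rather than inside each operator. A rough, hypothetical sketch of the
kind of mapping involved (the reference model's actual ConvertDType()
may differ in name, signature and details):

    // Hypothetical, simplified sketch only -- not the real implementation.
    TOSA_REF_TYPE ConvertDType(tosa::DType dtype, bool precise_mode)
    {
        if (precise_mode)
        {
            switch (dtype)
            {
                case tosa::DType_FP16:
                case tosa::DType_BF16:
                case tosa::DType_FP32:
                    // In precise mode every float tensor is computed in FP64.
                    return TOSA_REF_TYPE_FP64;
                default:
                    break;
            }
        }
        // Otherwise each serialized DType keeps its usual one-to-one mapping
        // (DType_INT8 -> TOSA_REF_TYPE_INT8, DType_FP32 -> TOSA_REF_TYPE_FP32, ...).
        return OneToOneMapping(dtype);    // hypothetical helper, elided here
    }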

Also adds an optional -p argument to the test runners
tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to
run tests in precise mode.
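
For example (illustrative invocations only; all other options are
unchanged):

    tosa_reference_model --precise_mode=1 <existing options>
    tosa_verif_run_tests.py -p <existing options>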

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b3845df..f8fd323 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -116,14 +116,14 @@
 }
 
 int check_conv_attribute(tosa::TosaConvAttribute* attribute,
-                               uint32_t conv_dimension,
-                               std::vector<int32_t> input_shape,
-                               std::vector<int32_t> output_shape,
-                               std::vector<int32_t> weights,
-                               uint32_t offset_kernel,
-                               DType InDtype,
-                               DType WeightDtype,
-                               std::string& msg)
+                         uint32_t conv_dimension,
+                         std::vector<int32_t> input_shape,
+                         std::vector<int32_t> output_shape,
+                         std::vector<int32_t> weights,
+                         uint32_t offset_kernel,
+                         TOSA_REF_TYPE InDtype,
+                         TOSA_REF_TYPE WeightDtype,
+                         std::string& msg)
 {
     if (attribute->pad().size() != (2 * conv_dimension))
     {
@@ -226,11 +226,13 @@
         return 1;
     }
 
-    if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
+    if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
+    {
         msg = "Input zero point must be zero for non-int8 data";
         return 1;
     }
-    if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
+    if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
+    {
         msg = "Weight zero point must be zero for non-int8 data";
         return 1;
     }
@@ -318,7 +320,7 @@
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
@@ -330,14 +332,14 @@
     INIT_ATTRIBUTE(Axis);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpArgMax<Rank, Dtype>::~OpArgMax()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpArgMax<Rank, Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -355,7 +357,7 @@
         return 1;
     }
 
-    if (outputs[0]->getDtype() != DType_INT32)
+    if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
     {
         printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
         return 1;
@@ -400,7 +402,7 @@
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpArgMax<Rank, Dtype>::eval()
 {
     Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
@@ -410,7 +412,7 @@
     return GraphNode::eval();
 }
 
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
@@ -422,14 +424,14 @@
     INIT_ATTRIBUTE(Pool);
 }
 
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -449,8 +451,10 @@
     in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
-    ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+             "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
+             "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
 
     std::string msg;
     if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
@@ -466,8 +470,9 @@
 // This calculates the number of padding elements used for each location along an axis
 // Average pooling only divides by the number of elements used, not including padding.
 // This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype, DType AccDtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
+    int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
 {
     ETensor1<int32_t> result(out_size);
 
@@ -495,7 +500,7 @@
 
 // assuming input and output tensor have same scales like tflite reference
 // so no need to scale input and output
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 int OpAvgPool2d<Dtype, AccDtype>::eval()
 {
     int in_batch    = this->in->getShape()[0];
@@ -531,7 +536,7 @@
     LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
     LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
 
-    tosa::DType accum_dtype       = (tosa::DType)this->attribute->accum_dtype();
+    TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());
 
     DEBUG_INFO(OP,
                "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
@@ -556,7 +561,7 @@
     pad[3] = std::make_pair(0, 0);
 
     ETensor4<InEigenType> input_val = this->in->getTensor();
-    if (Dtype == DType_INT8)
+    if (Dtype == TOSA_REF_TYPE_INT8)
     {
         input_val = input_val - (InEigenType)attribute->input_zp();
     }
@@ -604,7 +609,8 @@
         dm2_h.contract(dm2_w, contract_dims)
             .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
             .broadcast(bcast);
-    if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
+    if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
+        Dtype != TOSA_REF_TYPE_FP64)
     {
         try
         {
@@ -632,7 +638,7 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
@@ -644,14 +650,14 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -688,7 +694,7 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
@@ -781,7 +787,7 @@
 
     TIn input_val      = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val  = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -817,7 +823,7 @@
     // reshape back to [N, H, W, C]
     this->output->getTensor() = biased_output.reshape(col2im_output_dims);
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -826,7 +832,7 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
@@ -838,14 +844,14 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -882,7 +888,7 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
@@ -959,7 +965,7 @@
 
     TIn input_val      = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val  = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1020,7 +1026,7 @@
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1029,10 +1035,10 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
-                                                           TosaAttributeBase* attribute_,
-                                                           uint64_t id_)
+                                                                     TosaAttributeBase* attribute_,
+                                                                     uint64_t id_)
     : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1041,14 +1047,14 @@
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1085,7 +1091,7 @@
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
@@ -1149,7 +1155,7 @@
 
     TIn input_val      = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val  = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1205,7 +1211,7 @@
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1214,10 +1220,10 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
-                                                         TosaAttributeBase* attribute_,
-                                                         uint64_t id_)
+                                                                   TosaAttributeBase* attribute_,
+                                                                   uint64_t id_)
     : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
 {
     setRequiredOperands(3, 1);
@@ -1226,14 +1232,14 @@
     INIT_ATTRIBUTE(FullyConnected);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1265,13 +1271,15 @@
 
     output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
-    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
+    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+             "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
+    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+             "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
 
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1289,7 +1297,7 @@
 
     TIn input_val      = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val  = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1299,7 +1307,7 @@
         input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
             this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1307,7 +1315,7 @@
     return GraphNode::eval();
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
@@ -1319,14 +1327,14 @@
     INIT_ATTRIBUTE(MatMul);
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 OpMatMul<Dtype, OutDtype>::~OpMatMul()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1382,13 +1390,15 @@
     }
     W = b->getShape()[2];
 
-    ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
-    ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
+             "OpMatMul: A zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
+             "OpMatMul: B zeropoint must be zero for non int8_t data");
 
     return 0;
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 int OpMatMul<Dtype, OutDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1396,7 +1406,7 @@
 
     TIn a_val = this->a->getTensor();
     TIn b_val = this->b->getTensor();
-    if (Dtype == DType_INT8)
+    if (Dtype == TOSA_REF_TYPE_INT8)
     {
         a_val = a_val - (InEigenType)attribute->a_zp();
         b_val = b_val - (InEigenType)attribute->b_zp();
@@ -1434,7 +1444,7 @@
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1443,7 +1453,7 @@
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
@@ -1455,14 +1465,14 @@
     INIT_ATTRIBUTE(Pool);
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpMaxPool2d<Dtype>::~OpMaxPool2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpMaxPool2d<Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1493,7 +1503,7 @@
     return 0;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpMaxPool2d<Dtype>::eval()
 {
     int in_batch    = this->in->getShape()[0];
@@ -1586,10 +1596,8 @@
     return GraphNode::eval();
 }
 
-template <DType Dtype>
-OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
-                        TosaAttributeBase* attribute_,
-                        uint64_t id_)
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
     : GraphNode(sgt_, Op_FFT2D, id_)
 {
     setRequiredOperands(2, 2);
@@ -1598,14 +1606,14 @@
     INIT_ATTRIBUTE(FFT);
 }
 
-template <DType Dtype>
-OpFFT2d<Dtype>::~OpFFT2d() {
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::~OpFFT2d()
+{
     if (attribute)
         delete attribute;
 }
 
-
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpFFT2d<Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1643,7 +1651,7 @@
     return 0;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpFFT2d<Dtype>::eval()
 {
     int in_real_batch = this->in_real->getShape()[0];
@@ -1709,7 +1717,7 @@
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
@@ -1719,11 +1727,11 @@
     setRequiredRank(3);
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpRFFT2d<Dtype>::~OpRFFT2d() {}
 
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpRFFT2d<Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1759,7 +1767,7 @@
     return 0;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpRFFT2d<Dtype>::eval()
 {
     int32_t in_batch = in->getShape()[0];
@@ -1815,10 +1823,10 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
-                                                           TosaAttributeBase* attribute_,
-                                                           uint64_t id_)
+                                                                     TosaAttributeBase* attribute_,
+                                                                     uint64_t id_)
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1827,14 +1835,14 @@
     INIT_ATTRIBUTE(TransposeConv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1923,13 +1931,15 @@
         return 1;
     }
 
-    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
-    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
+    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+             "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
+    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+             "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
 
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
@@ -1985,7 +1995,7 @@
 
     TIn input_val      = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val  = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -2040,7 +2050,7 @@
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -2055,6 +2065,7 @@
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
 
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
@@ -2062,6 +2073,7 @@
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
 
                                 // [in_t, weight_t, out_t]
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -2071,6 +2083,7 @@
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
@@ -2079,6 +2092,7 @@
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
@@ -2087,8 +2101,10 @@
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
@@ -2097,6 +2113,7 @@
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
@@ -2104,14 +2121,17 @@
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
 
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
 
 DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
@@ -2120,3 +2140,4 @@
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);