Add more ERROR_IF support

- Also delay tensor allocation until after operators are validated, so that
  an ERROR_IF is caught before a 0 or negative dimension sets the
  graph_status to UNPREDICTABLE (see the sketch below)
- ERROR_IF checks added for: Rescale, Argmax, FullyConnected, Matmul, Pad, Reshape, Slice, Transpose, Clamp, Concat, Equal, Greater, GreaterEqual, Table
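
The reordered driver flow, as a minimal sketch (the SubgraphTraverser
entry points are the ones this patch touches; the construction and the
surrounding setup in main() are assumed for illustration only):

    SubgraphTraverser main_gt(...);    // construction elided

    if (main_gt.initializeGraph())     // create tensor objects; no buffers yet
        goto done;
    if (main_gt.linkTensorsAndNodes())
        goto done;
    if (main_gt.validateGraph())       // runs each op's checkTensorAttributes();
        goto done;                     // ERROR_IF is reported at this stage
    if (main_gt.allocateTensor())      // buffers are sized only after every
        goto done;                     // dimension has been validated
    // ... evaluation proceeds as before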

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I4e1b3e5794fe195ce1a37e28443ae584645a3b91
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index e04a20b..0bf0697 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -84,6 +84,12 @@
         goto done;
     }
 
+    if (main_gt.allocateTensor())
+    {
+        WARNING("Failed to allocate tensor. Evaluation aborted.");
+        goto done;
+    }
+
     if (g_func_config.validate_only)
     {
         goto done;
@@ -251,9 +257,9 @@
 
             DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename);
 
-            if (tensor->allocate())
+            if (!tensor->is_allocated())
             {
-                WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+                WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str());
                 return 1;
             }
 
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 21677d5..c344bcb 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -25,14 +25,15 @@
 template <int Rank, DType Dtype>
 int OpClamp<Rank, Dtype>::register_fcn()
 {
-
     switch (Dtype)
     {
         case DType_FLOAT:
         {
             InEigenType min = (InEigenType)attribute->min_fp();
             InEigenType max = (InEigenType)attribute->max_fp();
-            this->fcn       = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+            ERROR_IF(max < min, "OpClamp: max smaller than min");
+
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
         }
         break;
         case DType_INT8:
@@ -40,7 +41,8 @@
         {
             InEigenType min = (InEigenType)attribute->min_int();
             InEigenType max = (InEigenType)attribute->max_int();
-            this->fcn       = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+            ERROR_IF(max < min, "OpClamp: max smaller than min");
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
         }
         break;
         default:
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 86326f5..f3e80f3 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -51,25 +51,49 @@
         printNodeValidationError("Concat operator must have at least one input tensor");
         return 1;
     }
+
+    int32_t num_inputs = inputs.size();
+
     // output and input must be the same types and rank
-    for (size_t i = 0; i < inputs.size(); i++)
+    for (int32_t i = 0; i < num_inputs; i++)
     {
         if (inputs[i]->matchRankType(*outputs[0]))
         {
-            printNodeValidationError("Concat operator input ranks and types must match");
+            printNodeValidationError("OpConcat: input ranks and types must match");
             return 1;
         }
         ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
     }
 
-    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
-
-    if (attribute->axis() < 0 || (size_t)attribute->axis() >= inputs[0]->getShape().size())
+    if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
     {
-        printNodeValidationError("Axis is beyond input tensor rank");
+        printNodeValidationError("OpConcat: axis is beyond output tensor rank");
         return 1;
     }
 
+    int32_t output_dim_on_axis = 0;
+    for (int32_t j = 0; j < num_inputs; j++)
+    {
+        for (int32_t i = 0; i < Rank; i++)
+        {
+            int32_t input_dim = inputs[j]->getShape()[i];
+            if (i == attribute->axis())
+            {
+                output_dim_on_axis += input_dim;
+            }
+            else if (input_dim != outputs[0]->getShape()[i])
+            {
+                printNodeValidationError("OpConcat: input dimension not matching output dimension");
+                return 1;
+            }
+        }
+    }
+
+    ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
+             "OpConcat: sum of input dimensions on axis does not equal output dimension on axis");
+
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
     return 0;
 }
 
@@ -135,14 +159,13 @@
         return 1;
     }
 
-    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
-    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
-    TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings =
-        dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
+    in       = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out      = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+    paddings = dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
 
-    for (int i = 0; i < Rank; i++)
+    if (this->qinfo && Dtype != DType_INT8)
     {
-        paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+        ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0");
     }
 
     return 0;
@@ -151,6 +174,14 @@
 template <int Rank, DType Dtype>
 int OpPad<Rank, Dtype>::eval()
 {
+    // Padding values come from an input tensor, so they can only be checked here in eval()
+    for (int i = 0; i < Rank; i++)
+    {
+        ERROR_IF((paddings->getTensor()(i, 0) < 0) || (paddings->getTensor()(i, 1) < 0),
+                 "OpPad: padding can't be smaller than 0");
+        paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+    }
+
     InEigenType pad_value = 0;
     if (this->qinfo)
     {
@@ -202,12 +233,20 @@
         return 1;
     }
 
+    ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
+             "OpReshape: input tensor size does not match output tensor size");
+
     for (uint32_t d = 0; d < OutRank; d++)
     {
         if (attribute->shape()[d] == -1)
         {
             minusOneCount++;
         }
+        else
+        {
+            ERROR_IF(attribute->shape()[d] != outputs[0]->getShape()[d],
+                     "OpReshape: new_shape doesn't match output shape");
+        }
     }
 
     if (minusOneCount > 1)
@@ -358,7 +397,7 @@
     : GraphNode(sgt_, Op_SLICE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
+    setRequiredRank(1, 4);
 
     INIT_ATTRIBUTE(Slice);
 }
@@ -391,23 +430,20 @@
     in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    for (size_t i = 0; i < attribute->begin().size(); i++)
-    {
-        begin_array[i] = attribute->begin()[i];
-    }
+    ERROR_IF((int32_t)attribute->begin().size() != in->getRank(),
+             "OpSlice: begin array length needs to be rank(input)");
+    ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
 
-    for (size_t i = 0; i < attribute->size().size(); i++)
+    for (int32_t i = 0; i < in->getRank(); i++)
     {
-        if (attribute->size()[i] != 0)
-        {
-            size_array[i] = attribute->size()[i];
-        }
-        else
-        {
-            // Tensorflow assigns a zero size to dimensions that are kept
-            // Eigen expects size to be the full size of the dimension
-            size_array[i] = in->getTensor().dimension(0);
-        }
+        int32_t b = attribute->begin()[i];
+        int32_t s = attribute->size()[i];
+        ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of bounds");
+        ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of bounds");
+        ERROR_IF(s <= 0, "OpSlice: size must be positive");
+        ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
+        begin_array[i] = b;
+        size_array[i]  = s;
     }
 
     return 0;
@@ -611,6 +647,7 @@
     for (int32_t d = 0; d < Rank; d++)
     {
         perm_array[d] = this->perm_tensor->getTensor().data()[d];
+        ERROR_IF(perm_array[d] < 0 || perm_array[d] >= Rank, "OpTranspose: permutation index out of bounds");
     }
 
     out->getTensor() = in->getTensor().shuffle(perm_array);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index c9c2602..9f44fc7 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -63,6 +63,7 @@
     Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
     TosaReference::TensorTemplate<TIn>* in;
     TosaReference::TensorTemplate<TOut>* out;
+    TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 2>>* paddings;
     TosaPadQuantInfo* qinfo;
 };
 
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 023158c..6808604 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -60,26 +60,16 @@
         return 1;
     }
 
-    // In some ops, only rank of input and output tensor needs to match
-    if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL)
-    {
-        if (inputs[0]->matchRank(*outputs[0]))
-        {
-            std::string err =
-                "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
-            printNodeValidationError(err.c_str());
-            return 1;
-        }
-    }
-    // Otherwise both rand/type of input and output must match
-    else if (inputs[0]->matchRankType(*outputs[0]))
+    if (inputs[0]->matchRank(*outputs[0]))
     {
         std::string err =
-            "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match";
+            "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
         printNodeValidationError(err.c_str());
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator output type doesn't match expected output type");
+
     a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
     result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -532,6 +522,7 @@
             printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
             return 1;
         }
+        ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8");
     }
     else if (inputs[0]->getDtype() == DType_INT16)
     {
@@ -540,6 +531,7 @@
             printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
             return 1;
         }
+        ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32");
     }
 
     in    = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 118d048..be4e4aa 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -115,7 +115,7 @@
     : GraphNode(sgt_, Op_ARGMAX, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
+    setRequiredRank(1, 4);
 
     INIT_ATTRIBUTE(Axis);
 }
@@ -133,14 +133,60 @@
     if (validateRequiredOperands())
         return 1;
 
-    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+    if (validateRequiredRank(inputs[0]))
     {
         return 1;
     }
 
+    int32_t output_rank = inputs[0]->getRank() - 1;
+    if (output_rank != outputs[0]->getRank())
+    {
+        printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
+        return 1;
+    }
+
+    if (outputs[0]->getDtype() != DType_INT32)
+    {
+        printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
+        return 1;
+    }
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
+    if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
+    {
+        printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
+        return 1;
+    }
+
+    bool shape_check = true;
+    for (int32_t i = 0; i < input->getRank(); i++)
+    {
+        if (i < attribute->axis())
+        {
+            if (input->getShape()[i] != output->getShape()[i])
+            {
+                shape_check = false;
+                break;
+            }
+        }
+        else if (i > attribute->axis())
+        {
+            if (input->getShape()[i] != output->getShape()[i - 1])
+            {
+                shape_check = false;
+                break;
+            }
+        }
+        // No need to check i == axis
+    }
+    if (!shape_check)
+    {
+        printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
+        return 1;
+    }
+
     return 0;
 }
 
@@ -411,6 +457,9 @@
         printNodeValidationError("OpConv2d: bias tensor must be rank 1");
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpConv2d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -434,6 +483,18 @@
         return 1;
     }
 
+    if (this->qinfo)
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -603,6 +664,9 @@
         printNodeValidationError("OpConv3d: bias tensor must be rank 1");
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpConv3d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -626,6 +690,18 @@
         return 1;
     }
 
+    if (this->qinfo)
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -798,6 +874,9 @@
         printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -821,6 +900,18 @@
         return 1;
     }
 
+    if (this->qinfo)
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -987,8 +1078,23 @@
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpFullyConnected: Output data type not supported for this configuration of operator");
+
     output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
 
+    if (this->qinfo)
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -1059,6 +1165,9 @@
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpMatMul: Output data type not supported for this configuration of operator");
+
     a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
     output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
@@ -1101,6 +1210,12 @@
     }
     W = b->getShape()[2];
 
+    if (this->qinfo && Dtype != DType_INT8)
+    {
+        ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+        ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+    }
+
     return 0;
 }
 
@@ -1291,11 +1406,11 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
-                                                        TosaAttributeBase* attribute_,
-                                                        TosaQuantInfoBase* qinfo_,
-                                                        uint64_t id_)
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+                                                           TosaAttributeBase* attribute_,
+                                                           TosaQuantInfoBase* qinfo_,
+                                                           uint64_t id_)
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1305,8 +1420,8 @@
     INIT_QINFO(Conv);
 }
 
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
@@ -1314,8 +1429,8 @@
         delete qinfo;
 }
 
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1325,6 +1440,9 @@
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpTransposeConv2d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -1363,11 +1481,23 @@
         }
     }
 
+    if (this->qinfo)
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::eval()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 657eebf..e46ab38 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -30,7 +30,7 @@
     : GraphNode(sgt_, Op_RESCALE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
+    setRequiredRank(0, 4);
     INIT_ATTRIBUTE(Rescale);
 }
 
@@ -64,6 +64,30 @@
 
     ASSERT_MEM(in && out);
 
+    if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0))
+    {
+        printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0");
+        return 1;
+    }
+
+    if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0))
+    {
+        printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0");
+        return 1;
+    }
+
+    if (attribute->scale32() && (InDtype == DType_INT48))
+    {
+        printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
+        return 1;
+    }
+
+    if ((!attribute->scale32()) && attribute->double_round())
+    {
+        printNodeValidationError("OpRescale: Scale set to false but double round set to true");
+        return 1;
+    }
+
     return 0;
 }
 
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 82de69c..36e0a63 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -14,7 +14,6 @@
 //    limitations under the License.
 
 #include "subgraph_traverser.h"
-#include <unordered_set>
 
 #ifndef SUBGRAPH_ERROR_IF
 #define SUBGRAPH_ERROR_IF(COND, fmt, ...)                                                                              \
@@ -119,9 +118,6 @@
 {
     int idx = 0;
 
-    // tensor name set which contains all the name used by operator
-    std::unordered_set<std::string> used_tensor_name_set;
-
     for (auto op : block->GetOperators())
     {
         // translated TosaSerializationOperator to GraphNode
@@ -266,6 +262,63 @@
 
     for (auto ts : block->GetTensors())
     {
+        DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
+        TosaReference::Tensor* tensor =
+            TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+
+        SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
+                          ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
+
+        // update this->tensors
+        addTensor(tensor);
+    }
+
+    DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
+    for (auto& input_name : block->GetInputs())
+    {
+        TosaReference::Tensor* tensor = findTensorByName(input_name);
+        DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
+        if (tensor)
+        {
+            tensor->setIsSubgraphInput();
+            inputTensors.push_back(tensor);
+        }
+        else
+        {
+            SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
+                              input_name.c_str());
+        }
+    }
+
+    DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
+    for (auto& output_name : block->GetOutputs())
+    {
+        TosaReference::Tensor* tensor = findTensorByName(output_name);
+        DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+        if (tensor)
+        {
+            tensor->setIsSubgraphOutput();
+            outputTensors.push_back(tensor);
+        }
+        else
+        {
+            SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
+                              output_name.c_str());
+        }
+    }
+
+    if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+    {
+        dumpNextNodeList(g_func_debug.func_debug_file);
+    }
+
+    return 0;
+}
+
+int SubgraphTraverser::allocateTensor()
+{
+    for (auto ts : block->GetTensors())
+    {
         // Bail out if tensor is used and any of its dimension is invalid.
         auto got = used_tensor_name_set.find(ts->GetName());
         if (got != used_tensor_name_set.end())
@@ -280,20 +333,18 @@
             }
         }
 
-        DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
-        TosaReference::Tensor* tensor =
-            TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+        TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
+        SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
 
-        SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
-                          ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
+        DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
+        if (tensor->allocate())
+        {
+            FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
+        }
 
         if (!ts->GetData().empty())
         {
-            if (tensor->allocate())
-            {
-                FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
-            }
-
+            DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
             switch (ts->GetDtype())
             {
                 case DType_INT4:
@@ -361,48 +412,6 @@
                                       EnumNamesDType()[ts->GetDtype()]);
             }
         }
-
-        // update this->tensors
-        addTensor(tensor);
-    }
-
-    DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
-    for (auto& input_name : block->GetInputs())
-    {
-        TosaReference::Tensor* tensor = findTensorByName(input_name);
-        DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
-        if (tensor)
-        {
-            tensor->setIsSubgraphInput();
-            inputTensors.push_back(tensor);
-        }
-        else
-        {
-            SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
-                              input_name.c_str());
-        }
-    }
-
-    DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
-    for (auto& output_name : block->GetOutputs())
-    {
-        TosaReference::Tensor* tensor = findTensorByName(output_name);
-        DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
-        if (tensor)
-        {
-            tensor->setIsSubgraphOutput();
-            outputTensors.push_back(tensor);
-        }
-        else
-        {
-            SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
-                              output_name.c_str());
-        }
-    }
-
-    if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
-    {
-        dumpNextNodeList(g_func_debug.func_debug_file);
     }
 
     return 0;
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 4be6c1f..d53a4c0 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -21,6 +21,7 @@
 #include "ops/op_factory.h"
 #include "tensor.h"
 #include "tosa_serialization_handler.h"
+#include <unordered_set>
 
 namespace TosaReference
 {
@@ -54,6 +55,7 @@
 
     int linkTensorsAndNodes();
     int validateGraph();
+    int allocateTensor();
 
     int dumpGraph(FILE* out) const;
     int dumpNextNodeList(FILE* out) const;
@@ -99,6 +101,9 @@
     // lifetime, although the list itself should only contain unique nodes.
     std::list<GraphNode*> nextNodeList;
 
+    // set of all tensor names used by operators
+    std::unordered_set<std::string> used_tensor_name_set;
+
     // Maximum number of times to evalute a node before
     // warning.
     const int MAX_EVAL_COUNT = 10000;