Add more ERROR_IF checks

- Also delay tensor allocation until after the operator has been validated, so that
  ERROR_IF conditions are caught before a zero or negative dimension sets graph_status to UNPREDICTABLE
- Add ERROR_IF checks to Rescale, ArgMax, FullyConnected, MatMul, Pad, Reshape, Slice, Transpose, Clamp, Concat,
  Equal, Greater, GreaterEqual, Table, Conv2d, Conv3d, DepthwiseConv2d and TransposeConv2d

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I4e1b3e5794fe195ce1a37e28443ae584645a3b91
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 21677d5..c344bcb 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -25,14 +25,15 @@
 template <int Rank, DType Dtype>
 int OpClamp<Rank, Dtype>::register_fcn()
 {
-
     switch (Dtype)
     {
         case DType_FLOAT:
         {
             InEigenType min = (InEigenType)attribute->min_fp();
             InEigenType max = (InEigenType)attribute->max_fp();
-            this->fcn       = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+            ERROR_IF(max < min, "OpClamp: max smaller than min");
+            // Element-wise clamp of each input value into [min, max].
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
         }
         break;
         case DType_INT8:
@@ -40,7 +41,8 @@
         {
             InEigenType min = (InEigenType)attribute->min_int();
             InEigenType max = (InEigenType)attribute->max_int();
-            this->fcn       = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+            ERROR_IF(max < min, "OpClamp: max smaller than min");
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
         }
         break;
         default:
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 86326f5..f3e80f3 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -51,25 +51,49 @@
         printNodeValidationError("Concat operator must have at least one input tensor");
         return 1;
     }
+
+    int32_t num_inputs = inputs.size();
+
     // output and input must be the same types and rank
-    for (size_t i = 0; i < inputs.size(); i++)
+    for (int32_t i = 0; i < num_inputs; i++)
     {
         if (inputs[i]->matchRankType(*outputs[0]))
         {
-            printNodeValidationError("Concat operator input ranks and types must match");
+            printNodeValidationError("OpConcat: input ranks and types must match");
             return 1;
         }
         ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
     }
 
-    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
-
-    if (attribute->axis() < 0 || (size_t)attribute->axis() >= inputs[0]->getShape().size())
+    if (attribute->axis() < 0 || attribute->axis() >= Rank)
     {
-        printNodeValidationError("Axis is beyond input tensor rank");
+        printNodeValidationError("OpConcat: axis is beyond output tensor rank");
         return 1;
     }
 
+    int32_t output_dim_on_axis = 0;
+    for (int32_t j = 0; j < num_inputs; j++)
+    {
+        for (int32_t i = 0; i < Rank; i++)
+        {
+            int32_t input_dim = inputs[j]->getShape()[i];
+            if (i == attribute->axis())
+            {
+                output_dim_on_axis += input_dim;
+            }
+            else if (input_dim != outputs[0]->getShape()[i])
+            {
+                printNodeValidationError("OpConcat: input dimension not matching output dimension");
+                return 1;
+            }
+        }
+    }
+    // The output extent on the concat axis must equal the sum of the input extents on that axis.
+    ERROR_IF(output_dim_on_axis != outputs[0]->getShape()[attribute->axis()],
+             "OpConcat: sum of input dimension on axis not equal to output dimension on axis");
+
+    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
     return 0;
 }
 
@@ -135,14 +159,13 @@
         return 1;
     }
 
-    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
-    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
-    TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings =
-        dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
+    in       = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out      = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+    paddings = dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
 
-    for (int i = 0; i < Rank; i++)
+    if (this->qinfo && Dtype != DType_INT8)
     {
-        paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+        ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0");
     }
 
     return 0;
@@ -151,6 +174,14 @@
 template <int Rank, DType Dtype>
 int OpPad<Rank, Dtype>::eval()
 {
+    // Padding amounts are validated and copied into paddings_array here, not in checkTensorAttributes().
+    for (int i = 0; i < Rank; i++)
+    {
+        ERROR_IF((paddings->getTensor()(i, 0) < 0) || (paddings->getTensor()(i, 1) < 0),
+                 "OpPad: padding can't be smaller than 0");
+        paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+    }
+
     InEigenType pad_value = 0;
     if (this->qinfo)
     {
@@ -202,12 +233,20 @@
         return 1;
     }
 
+    ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
+             "OpReshape: input tensor size does not match output tensor size");
+
     for (uint32_t d = 0; d < OutRank; d++)
     {
         if (attribute->shape()[d] == -1)
         {
             minusOneCount++;
         }
+        else  // every explicit (non -1) new_shape entry must equal the output dimension
+        {
+            ERROR_IF(attribute->shape()[d] != outputs[0]->getShape()[d],
+                     "OpReshape: new_shape doesn't match output shape");
+        }
     }
 
     if (minusOneCount > 1)
@@ -358,7 +397,7 @@
     : GraphNode(sgt_, Op_SLICE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
+    setRequiredRank(1, 4);
 
     INIT_ATTRIBUTE(Slice);
 }
@@ -391,23 +430,22 @@
     in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    for (size_t i = 0; i < attribute->begin().size(); i++)
-    {
-        begin_array[i] = attribute->begin()[i];
-    }
+    ERROR_IF((int32_t)attribute->begin().size() != in->getRank(),
+             "OpSlice: begin array length needs to be rank(input)");
+    ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
 
-    for (size_t i = 0; i < attribute->size().size(); i++)
+    for (int32_t i = 0; i < in->getRank(); i++)
     {
-        if (attribute->size()[i] != 0)
-        {
-            size_array[i] = attribute->size()[i];
-        }
-        else
-        {
-            // Tensorflow assigns a zero size to dimensions that are kept
-            // Eigen expects size to be the full size of the dimension
-            size_array[i] = in->getTensor().dimension(0);
-        }
+        int32_t b = attribute->begin()[i];
+        int32_t s = attribute->size()[i];
+        // Widen before adding so a malformed start/size pair cannot overflow int32_t (UB).
+        int64_t end = static_cast<int64_t>(b) + s;
+        ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
+        ERROR_IF(end < 0 || end > in->getShape()[i], "OpSlice: (start+size) out of boundary");
+        ERROR_IF(s <= 0, "OpSlice: output must be positive");
+        ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
+        begin_array[i] = b;
+        size_array[i]  = s;
     }
 
     return 0;
@@ -611,6 +649,7 @@
     for (int32_t d = 0; d < Rank; d++)
     {
         perm_array[d] = this->perm_tensor->getTensor().data()[d];
+        ERROR_IF(perm_array[d] < 0 || perm_array[d] >= Rank, "OpTranspose: index out of boundary");
     }
 
     out->getTensor() = in->getTensor().shuffle(perm_array);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index c9c2602..9f44fc7 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -63,6 +63,7 @@
     Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
     TosaReference::TensorTemplate<TIn>* in;
     TosaReference::TensorTemplate<TOut>* out;
+    TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 2>>* paddings;  // padding amounts from inputs[1]
     TosaPadQuantInfo* qinfo;
 };
 
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 023158c..6808604 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -60,26 +60,16 @@
         return 1;
     }
 
-    // In some ops, only rank of input and output tensor needs to match
-    if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL)
-    {
-        if (inputs[0]->matchRank(*outputs[0]))
-        {
-            std::string err =
-                "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
-            printNodeValidationError(err.c_str());
-            return 1;
-        }
-    }
-    // Otherwise both rand/type of input and output must match
-    else if (inputs[0]->matchRankType(*outputs[0]))
+    if (inputs[0]->matchRank(*outputs[0]))  // rank only; the output dtype is checked by the ERROR_IF below
     {
         std::string err =
-            "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match";
+            "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
         printNodeValidationError(err.c_str());
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
+
     a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
     result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -532,6 +522,7 @@
             printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
             return 1;
         }
+        ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8");  // INT8 in -> INT8 out
     }
     else if (inputs[0]->getDtype() == DType_INT16)
     {
@@ -540,6 +531,7 @@
             printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
             return 1;
         }
+        ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32");  // INT16 in -> INT32 out
     }
 
     in    = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 118d048..be4e4aa 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -115,7 +115,7 @@
     : GraphNode(sgt_, Op_ARGMAX, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
+    setRequiredRank(1, 4);
 
     INIT_ATTRIBUTE(Axis);
 }
@@ -133,14 +133,60 @@
     if (validateRequiredOperands())
         return 1;
 
-    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+    if (validateRequiredRank(inputs[0]))
     {
         return 1;
     }
 
+    int32_t output_rank = inputs[0]->getRank() - 1;
+    if (output_rank != outputs[0]->getRank())
+    {
+        printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
+        return 1;
+    }
+
+    if (outputs[0]->getDtype() != DType_INT32)
+    {
+        printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
+        return 1;
+    }
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
+    if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
+    {
+        printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input) - 1]");
+        return 1;
+    }
+    // The output shape must equal the input shape with the axis dimension removed.
+    bool shape_check = true;
+    for (int32_t i = 0; i < input->getRank(); i++)
+    {
+        if (i < attribute->axis())
+        {
+            if (input->getShape()[i] != output->getShape()[i])
+            {
+                shape_check = false;
+                break;
+            }
+        }
+        else if (i > attribute->axis())
+        {
+            if (input->getShape()[i] != output->getShape()[i - 1])
+            {
+                shape_check = false;
+                break;
+            }
+        }
+        // No need to check i == axis
+    }
+    if (!shape_check)
+    {
+        printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
+        return 1;
+    }
+
     return 0;
 }
 
@@ -411,6 +457,9 @@
         printNodeValidationError("OpConv2d: bias tensor must be rank 1");
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpConv2d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -434,6 +483,18 @@
         return 1;
     }
 
+    if (this->qinfo)  // a non-zero zero point is only accepted for INT8 input/weight types
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -603,6 +664,9 @@
         printNodeValidationError("OpConv3d: bias tensor must be rank 1");
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpConv3d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -626,6 +690,18 @@
         return 1;
     }
 
+    if (this->qinfo)  // a non-zero zero point is only accepted for INT8 input/weight types
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -798,6 +874,9 @@
         printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -821,6 +900,18 @@
         return 1;
     }
 
+    if (this->qinfo)  // a non-zero zero point is only accepted for INT8 input/weight types
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -987,8 +1078,23 @@
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpFullyConnected: Output data type not supported for this configuration of operator");
+
     output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
 
+    if (this->qinfo)  // a non-zero zero point is only accepted for INT8 input/weight types
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
@@ -1059,6 +1165,9 @@
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpMatMul: Output data type not supported for this configuration of operator");
+
     a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
     output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
@@ -1101,6 +1210,12 @@
     }
     W = b->getShape()[2];
 
+    if (this->qinfo && Dtype != DType_INT8)  // guard a missing qinfo, as the other ops do
+    {
+        ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+        ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+    }
+
     return 0;
 }
 
@@ -1291,11 +1406,11 @@
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
-                                                        TosaAttributeBase* attribute_,
-                                                        TosaQuantInfoBase* qinfo_,
-                                                        uint64_t id_)
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+                                                           TosaAttributeBase* attribute_,
+                                                           TosaQuantInfoBase* qinfo_,
+                                                           uint64_t id_)
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1305,8 +1420,8 @@
     INIT_QINFO(Conv);
 }
 
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
@@ -1314,8 +1429,8 @@
         delete qinfo;
 }
 
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1325,6 +1440,9 @@
         return 1;
     }
 
+    ERROR_IF(outputs[0]->getDtype() != AccDtype,
+             "OpTransposeConv2d: Output data type not supported for this configuration of operator");
+
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -1363,11 +1481,23 @@
         }
     }
 
+    if (this->qinfo)  // a non-zero zero point is only accepted for INT8 input/weight types
+    {
+        if (InDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+        }
+        if (WeightDtype != DType_INT8)
+        {
+            ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+        }
+    }
+
     return 0;
 }
 
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::eval()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::eval()
 {
     int in_batch    = this->input->getShape()[0];
     int in_height   = this->input->getShape()[1];
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 657eebf..e46ab38 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -30,7 +30,7 @@
     : GraphNode(sgt_, Op_RESCALE, id_)
 {
     setRequiredOperands(1, 1);
-    setRequiredRank(0, 6);
+    setRequiredRank(0, 4);  // NOTE(review): presumably the allowed (min, max) tensor rank range; confirm
     INIT_ATTRIBUTE(Rescale);
 }
 
@@ -64,6 +64,30 @@
 
     ASSERT_MEM(in && out);
 
+    if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0))
+    {
+        printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0");
+        return 1;
+    }
+
+    if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0))
+    {
+        printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0");
+        return 1;
+    }
+
+    if (attribute->scale32() && (InDtype == DType_INT48))
+    {
+        printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
+        return 1;
+    }
+
+    if ((!attribute->scale32()) && attribute->double_round())
+    {
+        printNodeValidationError("OpRescale: Scale set to false but double round set to true");
+        return 1;
+    }
+
     return 0;
 }