Add more ERROR_IF checks for convolution op attributes

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I49d498dd3d4c069d8d1db07310f939268b9df4b7
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 5494d77..7942a24 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -21,10 +21,10 @@
 using namespace Eigen;
 using namespace tosa;
 
-int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
-                                  std::vector<int32_t> input_shape,
-                                  std::vector<int32_t> output_shape,
-                                  std::string& msg)
+int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
+                           std::vector<int32_t> input_shape,
+                           std::vector<int32_t> output_shape,
+                           std::string& msg)
 {
     if (attribute->padding().size() != 4)
     {
@@ -57,7 +57,7 @@
     {
         if (i < 1)
         {
-            msg = "At least one kernel dimension is smaller than zero";
+            msg = "At least one kernel dimension is smaller than one";
             return 1;
         }
     }
@@ -66,7 +66,7 @@
     {
         if (i < 1)
         {
-            msg = "At least one stride dimension is smaller than zero";
+            msg = "At least one stride dimension is smaller than one";
             return 1;
         }
     }
@@ -102,6 +102,91 @@
     return 0;
 }
 
+// Validate the attributes and quantization info shared by the convolution ops
+// (CONV2D, CONV3D and DEPTHWISE_CONV2D).
+//
+// attribute:      padding, stride and dilation to check
+// qinfo:          optional quantization info; zero points are checked only when non-null
+// conv_dimension: number of spatial dimensions (2 for CONV2D/DEPTHWISE_CONV2D, 3 for CONV3D)
+// input_shape:    input tensor shape (not used by any check yet)
+// output_shape:   output tensor shape (not used by any check yet)
+// InDtype:        input tensor dtype
+// WeightDtype:    weight tensor dtype
+// msg:            set to a description of the first failed check; untouched on success
+// Returns 0 when every check passes, 1 otherwise.
+int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
+                               tosa::TosaConvQuantInfo* qinfo,
+                               uint32_t conv_dimension,
+                               const std::vector<int32_t>& input_shape,
+                               const std::vector<int32_t>& output_shape,
+                               DType InDtype,
+                               DType WeightDtype,
+                               std::string& msg)
+{
+    // Two padding values per spatial dimension; one stride and one dilation per dimension.
+    if (attribute->padding().size() != (2 * conv_dimension))
+    {
+        msg = "Illegal size for attribute padding";
+        return 1;
+    }
+
+    if (attribute->stride().size() != conv_dimension)
+    {
+        msg = "Illegal size for attribute stride";
+        return 1;
+    }
+
+    if (attribute->dilation().size() != conv_dimension)
+    {
+        msg = "Illegal size for attribute dilation";
+        return 1;
+    }
+
+    for (int32_t i : attribute->padding())
+    {
+        if (i < 0)
+        {
+            msg = "At least one pad is smaller than zero";
+            return 1;
+        }
+    }
+
+    for (int32_t i : attribute->stride())
+    {
+        if (i < 1)
+        {
+            msg = "At least one stride dimension is smaller than one";
+            return 1;
+        }
+    }
+
+    for (int32_t i : attribute->dilation())
+    {
+        if (i < 1)
+        {
+            msg = "At least one dilation dimension is smaller than one";
+            return 1;
+        }
+    }
+
+    // Non-zero zero points are only allowed for int8 tensors; report which one is wrong.
+    if (qinfo)
+    {
+        if (InDtype != DType_INT8 && qinfo->input_zp() != 0)
+        {
+            msg = "input zeropoint only for int8_t";
+            return 1;
+        }
+        if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0)
+        {
+            msg = "weight zeropoint only for int8_t";
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
 template <int Rank, DType Dtype>
 OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
@@ -243,7 +314,7 @@
     }
 
     std::string msg;
-    if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg))
+    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
     {
         msg = "OpAvgPool2d: " + msg;
         printNodeValidationError(msg.c_str());
@@ -460,36 +531,15 @@
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
     output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
 
-    if (attribute->padding().size() != 4)
+    std::string msg;
+    if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+                                   InDtype, WeightDtype, msg))
     {
-        printNodeValidationError("OpConv2d: illegal size for attribute padding");
+        msg = "OpConv2d: " + msg;
+        printNodeValidationError(msg.c_str());
         return 1;
     }
 
-    if (attribute->stride().size() != 2)
-    {
-        printNodeValidationError("OpConv2d: illegal size for attribute stride");
-        return 1;
-    }
-
-    if (attribute->dilation().size() != 2)
-    {
-        printNodeValidationError("OpConv2d: illegal size for attribute dilation");
-        return 1;
-    }
-
-    if (this->qinfo)
-    {
-        if (InDtype != DType_INT8)
-        {
-            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t");
-        }
-        if (WeightDtype != DType_INT8)
-        {
-            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t");
-        }
-    }
-
     return 0;
 }
 
@@ -667,36 +717,15 @@
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
     output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
 
-    if (attribute->padding().size() != 6)
+    std::string msg;
+    if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(),
+                                   InDtype, WeightDtype, msg))
     {
-        printNodeValidationError("OpConv3d: illegal size for attribute padding");
+        msg = "OpConv3d: " + msg;
+        printNodeValidationError(msg.c_str());
         return 1;
     }
 
-    if (attribute->stride().size() != 3)
-    {
-        printNodeValidationError("OpConv3d: illegal size for attribute stride");
-        return 1;
-    }
-
-    if (attribute->dilation().size() != 3)
-    {
-        printNodeValidationError("OpConv3d: illegal size for attribute dilation");
-        return 1;
-    }
-
-    if (this->qinfo)
-    {
-        if (InDtype != DType_INT8)
-        {
-            ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t");
-        }
-        if (WeightDtype != DType_INT8)
-        {
-            ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t");
-        }
-    }
-
     return 0;
 }
 
@@ -877,36 +906,15 @@
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
     output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
 
-    if (attribute->padding().size() != 4)
+    std::string msg;
+    if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+                                   InDtype, WeightDtype, msg))
     {
-        printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding");
+        msg = "OpDepthwiseConv2d: " + msg;
+        printNodeValidationError(msg.c_str());
         return 1;
     }
 
-    if (attribute->stride().size() != 2)
-    {
-        printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride");
-        return 1;
-    }
-
-    if (attribute->dilation().size() != 2)
-    {
-        printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation");
-        return 1;
-    }
-
-    if (this->qinfo)
-    {
-        if (InDtype != DType_INT8)
-        {
-            ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
-        }
-        if (WeightDtype != DType_INT8)
-        {
-            ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
-        }
-    }
-
     return 0;
 }
 
@@ -1310,7 +1318,7 @@
     out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
     std::string msg;
-    if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg))
+    if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
     {
         msg = "OpMaxPool2d: " + msg;
         printNodeValidationError(msg.c_str());
@@ -1467,6 +1475,33 @@
         return 1;
     }
 
+    for (int32_t i : attribute->outpad())
+    {
+        if (i < 0)
+        {
+            printNodeValidationError("OpTransposeConv2d: At least one pad is smaller than zero");
+            return 1;
+        }
+    }
+
+    for (int32_t i : attribute->stride())
+    {
+        if (i < 1)
+        {
+            printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
+            return 1;
+        }
+    }
+
+    for (int32_t i : attribute->dilation())
+    {
+        if (i < 1)
+        {
+            printNodeValidationError("OpTransposeConv2d: At least one dilation is smaller than one");
+            return 1;
+        }
+    }
+
     for (int d = 0; d < 4; d++)
     {
         if (attribute->output_shape()[d] != this->output->getShape()[d])