[MLBEDSW-1996] Update supported operator checks

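Updated supported operator checks according to latest requirements:
- The Add, Sub and Mul operator variants now allow tensors wider than
  16 bits. Both inputs must have the same type; a signed input requires
  a signed 8-, 16- or 32-bit output, and an unsigned input requires an
  output of the same type or int32.
- Minimum/Maximum and the unary element-wise operators (LeakyRelu, Abs)
  require matching input and output types.
- The pooling kernel limits for VALID padding, and for max pooling with
  any padding, change from 256 in each dimension to a total kernel area
  of at most 256 * 256 with a height of at most 256.
- Max pooling no longer requires matching input and output types.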

Change-Id: I79708d8039e464e39818d3c09e61f3f533e96f3d
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 574b3a4..ce3fa60 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 # Description:
 # The SupportedOperators class which is a collection of all supported operators and parameter checks.
-from .data_type import BaseType
+from .data_type import BaseType, DataType
 
 
 class SupportedOperators:
@@ -45,9 +45,9 @@
             | set(("ResizeBilinear",))
         )
         self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs"))
-        self.binary_elem_wise_main_ops = set(
+        self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum"))  # IFM and OFM dtypes must match
+        self.binary_elem_wise_add_mul_sub = set(  # Add/Sub/Mul variants; exempt from 16-bit element size limit
             (
-                # binary element-wise
                 "AddAct",
                 "MulAct",
                 "SubAct",
@@ -57,10 +57,9 @@
                 "Mul",
                 "Add",
                 "Sub",
-                "Minimum",
-                "Maximum",
             )
         )
+        self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
         self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
         self.activation_ops = set(
             ("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh")
@@ -124,7 +123,7 @@
         for t in tensors:
             if not (t.dtype.type & BaseType.Int):
                 return False
-            if t.element_size() > 2 and op.type != "Requantize":
+            if t.element_size() > 2 and op.type not in set(("Requantize",)) | self.binary_elem_wise_add_mul_sub:
                 return False
             # check size
             if any(dim > 65536 for dim in t.shape):
@@ -197,15 +196,13 @@
             # check kernel size
             if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8):
                 return False
-            if op.attrs["padding"] == b"VALID" and (op.attrs["filter_width"] > 256 or op.attrs["filter_height"] > 256):
+            if (op.attrs["padding"] == b"VALID" and  # VALID: kernel area <= 256 * 256 and height <= 256
+                (op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256)):
                 return False
 
         if op.type in self.max_pooling_ops:
-            # check data type
-            if not ifm_tensor.dtype == ofm_tensor.dtype:
-                return False
-            # check kernel size
-            if op.attrs["filter_width"] > 256 or op.attrs["filter_height"] > 256:  # any padding
+            # check kernel size (any padding): kernel area <= 256 * 256 and height <= 256
+            if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256:
                 return False
         return True
 
@@ -220,8 +217,27 @@
     def check_element_wise_restrictions(self, op):
         # check data type
         ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
-        if op.type in ("Minimum", "Maximum") and ifm_tensor.dtype != ofm_tensor.dtype:
+        # input and output datatype must match for these operators
+        if (op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops and
+                ifm_tensor.dtype != ofm_tensor.dtype):
             return False
+        if op.type in self.binary_elem_wise_add_mul_sub:
+            # both inputs must have the same type
+            if ifm_tensor.dtype != ifm2_tensor.dtype:
+                return False
+            # signed input check
+            if ifm_tensor.dtype.type & BaseType.Signed:
+                # output must be signed (reject unsigned outputs)...
+                if ofm_tensor.dtype.type & BaseType.Unsigned:
+                    return False
+                # ...and 8, 16 or 32-bit
+                if ofm_tensor.element_size() not in (1, 2, 4):
+                    return False
+            # unsigned input check: output must be the same type or int32
+            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
+                    ifm_tensor.dtype == ofm_tensor.dtype or
+                    ofm_tensor.dtype == DataType.int32):
+                return False
 
         # check batch size
         if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1: