MLBEDSW-2412 Refactor constraints for conv ops

Use a new system for reporting constraints to replace the existing
functionality for checking conv-like ops.
The new system allows all constraints to be reported independently of
any input network.

Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: If81177deca2a3b57c9dd9a3a08868cbc9cef0c23
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 66c74fc..f4dd579 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 # Description:
 # The SupportedOperators class which is a collection of all supported operators and parameter checks.
+from collections import defaultdict
+
 import numpy as np
 
 from .data_type import BaseType
@@ -43,6 +45,7 @@
     convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
     depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
     transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
+    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
     max_pooling_ops = Op.op_set(Op.is_maxpool_op)
     avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
     pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
@@ -51,12 +54,8 @@
     mac_main_ops = (
         # RNN/LSTM/GRU
         set((Op.BlockLSTM,))
-        # convolutions
-        | convolution_ops
-        # depth-wise convolutions
-        | depthwise_convolution_ops
-        # transpose convolutions
-        | transpose_convolution_ops
+        # convolutions, depthwise convolutions and transpose convolutions
+        | convolution_like_ops
         # pooling
         | pooling_ops
         # resizing/upscaling
@@ -88,17 +87,21 @@
     shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
     supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
     supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
-    supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+    # Supported data types: op tensors (input(s), output and weights) and bias tensors
+    supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+    supported_bias_dtypes = set((DataType.int32, DataType.int64))
     # Defined ranges for allowed values:
     tens_dim_range = (1, 65535)
+    stride_range = (1, 3)  # (min, max), inclusive
+    dilation_range = (1, 2)  # (min, max), inclusive
+    dilated_height_range = (1, 64)  # bounds for (kernel_h - 1) * dilation_h + 1
+    dilated_product_range = (1, 64 * 64)  # bounds for dilated kernel width * dilated kernel height
+    weights_limit = 127 * 65536  # max allowed sum of |weight - zero_point| over the H, W and I axes
 
     def __init__(self):
         # Setup supported operator restriction checkers
         self.supported_operator_restrictions = {}
         self.supported_operator_restrictions.update(
-            {op: self.check_convolution_restrictions for op in SupportedOperators.convolution_ops}
-        )
-        self.supported_operator_restrictions.update(
             {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
         )
         self.supported_operator_restrictions.update(
@@ -134,13 +137,38 @@
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
         self.generic_constraints.append(SupportedOperators.constraint_faf)
+        # Setup specific constraints. Keys are tuples (hashable, unlike sets) of the op types the constraints apply to
+        self.specific_constraints = defaultdict(list)
+        # Conv-like ops have the same checks applied to them:
+        conv_like_ops = tuple(SupportedOperators.convolution_like_ops)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_type)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_range)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_type)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_range)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_height_range)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_product_range)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_type)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_nonconst)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_limit)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_type)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_40bit)
+        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_batch_size)
+
+    def get_constraints_list(self, op_type):
+        "Return the generic constraints followed by the specific constraints that apply to op_type"
+        constraint_list = list(self.generic_constraints)
+        for ops, constraints in self.specific_constraints.items():
+            if op_type in ops:
+                constraint_list.extend(constraints)
+        return constraint_list
 
     def is_operator_supported(self, op):
         if op.type not in SupportedOperators.supported_operators:
             if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
                 print("Info: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
             return False
-        for constraint in self.generic_constraints:
+
+        for constraint in self.get_constraints_list(op.type):
             valid, extra = constraint(op)
             if not valid:
                 print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
@@ -148,6 +175,7 @@
                 if extra:
                     print("   {}".format(extra))
                 return False
+
         if op.type in self.supported_operator_restrictions:
             return self.supported_operator_restrictions[op.type](op)
         return True
@@ -186,7 +214,7 @@
             if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                 valid = False
                 extra.append(tens.name)
-        extra = "Op '{}' has shapeless input tensor(s): {}".format(op.name, ", ".join(extra))
+        extra = "Op has shapeless input tensor(s): {}".format(", ".join(extra))  # caller prints op.name
         return valid, extra
 
     @staticmethod
@@ -202,15 +230,15 @@
         return valid, ", ".join(extra)
 
     @classmethod
-    @docstring_format_args([supported_dtypes])
+    @docstring_format_args([supported_op_dtypes])
     def constraint_tens_dtype(cls, op):
-        "Tensors must be of type: {}"
+        "Input(s), Output and Weight Tensors must be of type: {}"
         valid = True
         extra = []
         tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
         tensors = tensors if tensors else op.inputs
         for tens in tensors:
-            if tens.dtype not in cls.supported_dtypes:
+            if tens.dtype not in cls.supported_op_dtypes:
                 valid = False
                 extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
         return valid, ", ".join(extra)
@@ -227,13 +255,13 @@
             if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
                 valid = False
                 extra.append(tens.name)
-        extra = "Op '{}' has int32 tensor(s): {}".format(op.name, ", ".join(extra))
+        extra = "Op has int32 tensor(s): {}".format(", ".join(extra))  # caller prints op.name
         return valid, extra
 
     @classmethod
     @docstring_format_args(tens_dim_range)
     def constraint_tens_dimension(cls, op):
-        "Tensor dimensions must be in the range {}-{} (inclusive)"
+        "Tensor dimensions must be in the range [{}, {}]"
         tens_min, tens_max = cls.tens_dim_range
         valid = True
         extra = []
@@ -275,85 +303,132 @@
         "The fused activation function (if present) must be one of type: {}"
         faf = op.activation
         valid = (faf is None) or (faf in cls.supported_fused_activations)
-        extra = "Op '{}' has its fused activation function as: {}".format(op.name, faf)
+        extra = "Op has its fused activation function as: {}".format(faf)
+        return valid, extra
+
+    @staticmethod
+    def constraint_stride_type(op):
+        "Stride values for both width and height must be integer types"
+        w = op.attrs["stride_w"]
+        h = op.attrs["stride_h"]
+        valid = is_integer(w) and is_integer(h)
+        extra = "Op has stride WxH as: {}x{}".format(repr(w), repr(h))
         return valid, extra
 
     @classmethod
-    def check_convolution_restrictions(cls, op):
-        # check stride
-        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
-        if not is_integer(stride_w) or not is_integer(stride_h):
-            print("Warning:", op.type, "has non-integer stride, placing on CPU")
-            return False
-        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
-            print(
-                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
-                    op.type, stride_w, stride_h
-                )
-            )
-            return False
+    @docstring_format_args(stride_range)
+    def constraint_stride_range(cls, op):
+        "Stride values for both width and height must be in the range [{}, {}]"
+        w = op.attrs["stride_w"]
+        h = op.attrs["stride_h"]
+        stride_min, stride_max = cls.stride_range
+        valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
+        extra = "Op has stride WxH as: {}x{}".format(w, h)
+        return valid, extra
 
-        # check dilation
-        dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
-        dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
-        if not is_integer(dilation_w_factor) or not is_integer(dilation_h_factor):
-            print("Warning:", op.type, "has non-integer dilation factor, placing on CPU")
-            return False
-        if not 1 <= dilation_w_factor <= 2 or not 1 <= dilation_h_factor <= 2:
-            print(
-                "Warning:",
-                op.type,
-                "has dilation factors ({}, {}), only factors in range [1, 2] are allowed. Placing on CPU".format(
-                    dilation_w_factor, dilation_h_factor
-                ),
-            )
-            return False
+    @staticmethod
+    def constraint_dilation_type(op):
+        "Dilation factor values for both width and height must be integer types"
+        w = op.attrs.get("dilation_w_factor", 1)
+        h = op.attrs.get("dilation_h_factor", 1)
+        valid = is_integer(w) and is_integer(h)
+        extra = "Op has dilation factor WxH as: {}x{}".format(repr(w), repr(h))
+        return valid, extra
 
-        # check data type
-        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
-        if weight_tensor.element_size() > 1:
-            print("Warning: only 8-bit weights are supported, placing on CPU")
-            return False
+    @classmethod
+    @docstring_format_args(dilation_range)
+    def constraint_dilation_range(cls, op):
+        "Dilation factor values for both width and height must be in the range [{}, {}]"
+        w = op.attrs.get("dilation_w_factor", 1)
+        h = op.attrs.get("dilation_h_factor", 1)
+        dilation_min, dilation_max = cls.dilation_range
+        valid = (dilation_min <= w <= dilation_max) and (dilation_min <= h <= dilation_max)
+        extra = "Op has dilation factor WxH as: {}x{}".format(w, h)
+        return valid, extra
 
-        if not cls.check_bias_restrictions(bias_tensor):
-            return False
+    @classmethod
+    @docstring_format_args(dilated_height_range)
+    def constraint_dilated_height_range(cls, op):
+        "Dilated kernel height must be in the range [{}, {}]"
+        # Weights are laid out HWIO, so shape[0] is the kernel height
+        h = (op.weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
+        dilated_height_min, dilated_height_max = cls.dilated_height_range
+        valid = dilated_height_min <= h <= dilated_height_max
+        extra = "Op has dilated kernel height as: {}".format(h)
+        return valid, extra
 
-        # check kernel size [HWIO]
-        dilated_weight_w = (weight_tensor.shape[1] - 1) * dilation_w_factor + 1
-        dilated_weight_h = (weight_tensor.shape[0] - 1) * dilation_h_factor + 1
+    @classmethod
+    @docstring_format_args(dilated_product_range)
+    def constraint_dilated_product_range(cls, op):
+        "Product of dilated kernel width and height must be in the range [{}, {}]"
+        weights = op.weights
+        w = (weights.shape[1] - 1) * op.attrs.get("dilation_w_factor", 1) + 1
+        h = (weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
+        product = w * h
+        dilated_product_min, dilated_product_max = cls.dilated_product_range
+        valid = dilated_product_min <= product <= dilated_product_max
+        extra = "Op has product of dilated kernel width and height as: {}".format(product)
+        return valid, extra
 
-        # kernel limits
-        if not 1 <= dilated_weight_h <= 64:
-            print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU")
-            return False
-        if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64:
-            print(
-                "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64),
-                "placing on CPU",
-            )
-            return False
+    @staticmethod
+    def constraint_weights_type(op):
+        "Weight Tensor must be 8-bit"
+        weights = op.weights
+        valid = weights.element_size() == 1
+        extra = "Tensor '{}' is {}-bit".format(weights.name, int(weights.element_size() * 8))
+        return valid, extra
 
-        # check non const weights
-        if weight_tensor.values is None:
-            print("Warning:", op.type, "has non-constant weights, placing on CPU")
-            return False
+    @staticmethod
+    def constraint_weights_nonconst(op):
+        "Weight tensor cannot be non-constant"
+        weights = op.weights
+        valid = weights.values is not None
+        extra = "Tensor '{}' has non-constant values".format(weights.name)
+        return valid, extra
 
-        # check weight sums over [HWI]
-        zero_point = weight_tensor.quantization.zero_point
-        quant_weights = weight_tensor.quant_values.astype(np.int64)
-        weights = quant_weights - zero_point
-        totals = np.sum(np.absolute(weights), axis=(0, 1, 2))
+    @classmethod
+    @docstring_format_args([weights_limit])
+    def constraint_weights_limit(cls, op):
+        "The sum of the weights cannot exceed {}"
+        weights = op.weights
+        values = weights.quant_values.astype(np.int64) - weights.quantization.zero_point
+        # Per-output-channel sums of |weight - zero_point| over the H, W and I axes; check the largest
+        max_sum = np.amax(np.sum(np.absolute(values), axis=(0, 1, 2)))
+        valid = max_sum <= cls.weights_limit
+        extra = "Tensor '{}' has the sum of weights: {}".format(weights.name, max_sum)
+        return valid, extra
 
-        if np.amax(totals) > 127 * 65536:
-            print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536))
-            return False
+    @classmethod
+    @docstring_format_args([supported_bias_dtypes])
+    def constraint_bias_type(cls, op):
+        "Optional Bias Tensor must be of type: {}"
+        valid = True
+        extra = ""
+        bias = op.bias
+        if bias:
+            valid = bias.dtype in cls.supported_bias_dtypes
+            extra = "Tensor '{}' has data type: {}".format(bias.name, bias.dtype)
+        return valid, extra
 
-        # check batch size
-        if ifm_tensor.shape[0] != 1:
-            print("Warning: only batch sizes of 1 are supported, placing on CPU")
-            return False
+    @staticmethod
+    def constraint_bias_40bit(op):
+        "Optional Bias Tensor values must fit within 40-bits"
+        valid = True
+        extra = ""
+        bias = op.bias
+        if bias and bias.dtype == DataType.int64:
+            # Signed 40-bit range is [-2**39, 2**39); a bin() length test mis-handles the '-0b' prefix of negatives
+            valid = all(-(1 << 39) <= quant_value < (1 << 39) for quant_value in bias.quant_values)
+            extra = "Tensor '{}' has values larger than 40-bits".format(bias.name)
+        return valid, extra
 
-        return True
+    @staticmethod
+    def constraint_batch_size(op):
+        "IFM Tensor batch size must be 1"
+        ifm = op.ifm
+        valid = ifm.shape[0] == 1
+        extra = "Tensor '{}' has batch size: {}".format(ifm.name, ifm.shape[0])
+        return valid, extra
 
     @classmethod
     def check_depthwise_convolution_restrictions(cls, op):
@@ -368,7 +440,7 @@
                 "Placing on CPU",
             )
             return False
-        return cls.check_convolution_restrictions(op)
+        return True  # shared conv-like checks now run via specific_constraints
 
     @classmethod
     def check_transpose_convolution_restrictions(cls, op):
@@ -403,8 +475,7 @@
                     "minus difference between kernel size and stride. Placing on CPU",
                 )
                 return False
-
-        return cls.check_convolution_restrictions(op)
+        return True  # shared conv-like checks now run via specific_constraints
 
     @classmethod
     def check_pooling_restrictions(cls, op):