MLBEDSW-2412 Refactor all constraints into the new framework

All existing constraints have now been refactored using the new
framework: the per-op-type check_* restriction methods and the
supported_operator_restrictions lookup table are replaced by small
constraint_* functions. Each constraint carries its reason as a
docstring, returns a (valid, extra) tuple, and is registered per op
type in the generic_constraints and specific_constraints lists, which
is_operator_supported() now walks directly.
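
For illustration, a constraint in the new framework has this shape
(the example below is taken verbatim from this patch):

    @staticmethod
    def constraint_output_int32(op):
        "OFM must be int32"
        ofm_dtype = op.ofm.dtype
        valid = ofm_dtype == DataType.int32
        return valid, f"Op has ofm_dtype={ofm_dtype}"

    # Registered in __init__, e.g. for SHL:
    self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32)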

Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: Ic9ba0d7040cb9f114b959a949bfdf777f86752c7
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 24c7291..ddfb8ed 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -25,7 +25,6 @@
 from .operation import get_slice_offsets
 from .operation import Op
 from .tensor import check_quantized_tens_scaling_equal
-from .tensor import check_tens_quantized
 
 
 # Custom decorator function to allow formatting docstrings containing "{}"
@@ -74,7 +73,8 @@
     supported_int32_tensor_ops = (
         set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     )
-    activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
+    relu_ops = Op.op_set(Op.is_relu_op)
+    activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,))
     npu_post_ops = (
         # activation functions
         activation_ops
@@ -87,7 +87,7 @@
     concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
     memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
     shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
-    supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
+    supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
     supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
     # Supported data types
     supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
@@ -99,39 +99,17 @@
     dilated_height_range = (1, 64)
     dilated_product_range = (1, 64 * 64)
     weights_limit = 127 * 65536
+    filter_range = (1, 8)
+    filter_height_range = (1, 256)
+    filter_product_range = (1, 256 * 256)
 
     def __init__(self):
-        # Setup supported operator restriction checkers
-        self.supported_operator_restrictions = {}
-        self.supported_operator_restrictions.update(
-            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
-        )
-        self.supported_operator_restrictions.update(
-            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
-        )
-        self.supported_operator_restrictions.update(
-            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
-        )
-        self.supported_operator_restrictions.update(
-            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
-        )
-        self.supported_operator_restrictions.update(
-            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
-        )
-        self.supported_operator_restrictions.update(
-            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
-        )
-        self.supported_operator_restrictions.update(
-            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
-        )
-        self.supported_operator_restrictions.update(
-            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
-        )
         # Setup the generic constraints. Note: the order matters
         self.generic_constraints = []
+        self.generic_constraints.append(SupportedOperators.constraint_tens_no_dynamic)
         self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
-        self.generic_constraints.append(SupportedOperators.constraint_tens_output_shapeless)
-        self.generic_constraints.append(SupportedOperators.constraint_tens_input_shapeless)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_output_scalar)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_input_scalar)
         self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
         self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
         self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
@@ -139,76 +117,173 @@
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
         self.generic_constraints.append(SupportedOperators.constraint_faf)
-        # Setup specific constraints. The key in the dictionary must be a tuple of op types the constraints apply to
-        self.specific_constraints = defaultdict(list)
-        # Conv-like ops have the same checks applied to them:
-        conv_like_ops = tuple(SupportedOperators.convolution_like_ops)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_type)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_range)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_type)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_range)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_height_range)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_product_range)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_type)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_nonconst)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_limit)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_type)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_40bit)
-        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_batch_size)
 
-    def get_constraints_list(self, op_type):
-        constraint_list = list(self.generic_constraints)
-        for ops in self.specific_constraints:
-            if op_type in ops:
-                constraint_list.extend(self.specific_constraints[ops])
-        return constraint_list
+        # Setup specific constraints. Note: the order matters
+        self.specific_constraints = defaultdict(list)
+
+        # Conv-like checks:
+        for op_type in SupportedOperators.convolution_like_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_range)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_dilation_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_dilation_range)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_dilated_height_range)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_dilated_product_range)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_const)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_limit)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_40bit)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_batch_size)
+        # Depthwise Conv specific checks:
+        for op_type in SupportedOperators.depthwise_convolution_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_depth_multiplier)
+        # Transpose Conv specific checks:
+        for op_type in SupportedOperators.transpose_convolution_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_stride)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_same)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_valid)
+
+        # Pooling checks:
+        for op_type in SupportedOperators.pooling_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_batch_size)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_range)
+        # AVG pooling specific checks:
+        for op_type in SupportedOperators.avg_pooling_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_range)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_height_range_valid_pad)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_product_range_valid_pad)
+        # MAX pooling specific checks:
+        for op_type in SupportedOperators.max_pooling_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_height_range)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_product_range)
+        # TODO: Check ReduceSum restrictions
+
+        # Relu specific checks:
+        for op_type in SupportedOperators.relu_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_quant_scale_inf)
+
+        # Resizing specific checks:
+        for op_type in SupportedOperators.resizing_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_resize)
+
+        # Vector Product specific checks:
+        for op_type in SupportedOperators.fc_vector_products:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_const)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_type)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_40bit)
+
+        # Concat specific checks:
+        for op_type in (Op.Concat, Op.ConcatTFLite):
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_axis_exists)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_axis_valid)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_dimensionality)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_valid_dimensions)
+
+        # Element-wise checks:
+        for op_type in SupportedOperators.elem_wise_main_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_elemwise_batch_size)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_either_shapes)
+        # Unary specific checks:
+        for op_type in SupportedOperators.unary_elem_wise_main_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+        # Binary Min/Max specific checks:
+        for op_type in SupportedOperators.binary_elem_wise_min_max_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_quantization_parameters)
+        # Binary Add/Mul/Sub specific checks:
+        for op_type in SupportedOperators.binary_elem_wise_add_mul_sub:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_inputs_types)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_signed)
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_unsigned_valid)
+        # Binary Shift specific checks:
+        for op_type in SupportedOperators.binary_elem_wise_shift_ops:
+            self.specific_constraints[op_type].append(SupportedOperators.constraint_inputs_int32)
+
+        # SHL specific checks:
+        self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32)
+
+        # CLZ specific checks:
+        self.specific_constraints[Op.CLZ].append(SupportedOperators.constraint_output_int32)
+
+        # Softmax specific checks:
+        self.specific_constraints[Op.Softmax].append(SupportedOperators.constraint_matching_shapes)
+        self.specific_constraints[Op.Softmax].append(SupportedOperators.constraint_matching_in_out_types)
+
+        # SplitV specific checks:
+        self.specific_constraints[Op.SplitV].append(SupportedOperators.constraint_splitv_inferred)
+
+        # StridedSlice specific checks:
+        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_input_count)
+        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_inputs_const)
+        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_tens_size_matches)
+        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_stride_values)
+        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_ellipsis_mask)
+        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_axis_masks)
+        self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_slice_ranges)
+
+        # LeakyRelu specific checks:
+        self.specific_constraints[Op.LeakyRelu].append(SupportedOperators.constraint_alpha_valid)
 
     def is_operator_supported(self, op):
         if op.type not in SupportedOperators.supported_operators:
             if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
-                print("Info: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
+                print(f"Info: {op.type} '{op.name}' is not supported on the NPU. Placing on CPU instead")
             return False
 
-        for constraint in self.get_constraints_list(op.type):
+        for constraint in self.generic_constraints + self.specific_constraints[op.type]:
             valid, extra = constraint(op)
             if not valid:
-                print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
-                print(" - {}".format(constraint.__doc__))
+                print(f"Warning: {op.type} '{op.name}' is not supported on the NPU. Placing on CPU instead")
+                print(f" - {constraint.__doc__}")
                 if extra:
-                    print("   {}".format(extra))
+                    print(f"   {extra}")
                 return False
 
-        if op.type in self.supported_operator_restrictions:
-            return self.supported_operator_restrictions[op.type](op)
         return True
 
     @staticmethod
+    def constraint_tens_no_dynamic(op):
+        "Input(s) and Output tensors must not be dynamic"
+        valid = True
+        extra = []
+        tensors = [tens for tens in op.inputs + op.outputs if tens]
+        for tens in tensors:
+            if (tens.shape == []) and (tens.values is None):
+                valid = False
+                extra.append(tens.name)
+        extra = ", ".join(extra)
+        return valid, f"Op has dynamic tensor(s): {extra}"
+
+    @staticmethod
     def constraint_tens_defined_shape(op):
-        "Input(s) and Output Tensors must have a defined shape"
+        "Input(s) and Output tensors must have a defined shape"
         valid = True
         extra = []
         tensors = [tens for tens in op.inputs + op.outputs if tens]
         for tens in tensors:
             if not tens.has_fully_defined_shape():
                 valid = False
-                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
         return valid, ", ".join(extra)
 
     @staticmethod
-    def constraint_tens_output_shapeless(op):
-        "Scalar or Broadcasting Tensors are only valid for Input Tensors"
-        valid = True
-        extra = []
-        for tens in op.outputs:
-            if tens.shape == []:
-                valid = False
-                extra.append("Output Tensor '{}' is shapeless".format(tens.name))
-        return valid, ", ".join(extra)
+    def constraint_tens_output_scalar(op):
+        "Output tensors cannot be scalar"
+        ofm = op.ofm
+        valid = ofm.shape != []
+        return valid, f"Output Tensor '{ofm.name}' is scalar"
 
     @classmethod
     @docstring_format_args([shapeless_input_ops])
-    def constraint_tens_input_shapeless(cls, op):
-        "Scalar or Broadcasting Input Tensors are only valid for op type: {}"
+    def constraint_tens_input_scalar(cls, op):
+        "Scalar Input tensors are only valid for op type: {}"
         valid = True
         extra = []
         tensors = [tens for tens in op.inputs if tens]
@@ -216,33 +291,34 @@
             if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                 valid = False
                 extra.append(tens.name)
-        extra = "Op has shapeless input tensor(s): {}".format(", ".join(extra))
-        return valid, extra
+        extra = ", ".join(extra)
+        return valid, f"Op has scalar input tensor(s): {extra}"
 
     @staticmethod
     def constraint_tens_shape_size(op):
-        "Input(s) and Output Tensors must not be greater than 4D"
+        "Input(s) and Output tensors must not be greater than 4D"
         valid = True
         extra = []
         tensors = [tens for tens in op.inputs + op.outputs if tens]
         for tens in tensors:
             if len(tens.shape) > 4:
                 valid = False
-                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
         return valid, ", ".join(extra)
 
     @classmethod
     @docstring_format_args([supported_op_dtypes])
     def constraint_tens_dtype(cls, op):
-        "Input(s), Output and Weight Tensors must be of type: {}"
+        "Tensors must be of type: {}"
         valid = True
         extra = []
         tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
-        tensors = tensors if tensors else op.inputs
+        if not tensors:
+            tensors = [tens for tens in op.inputs if tens]
         for tens in tensors:
             if tens.dtype not in cls.supported_op_dtypes:
                 valid = False
-                extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
+                extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}")
         return valid, ", ".join(extra)
 
     @classmethod
@@ -252,13 +328,14 @@
         valid = True
         extra = []
         tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
-        tensors = tensors if tensors else op.inputs
+        if not tensors:
+            tensors = [tens for tens in op.inputs if tens]
         for tens in tensors:
             if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
                 valid = False
                 extra.append(tens.name)
-        extra = "Op has int32 tensor(s): {}".format(", ".join(extra))
-        return valid, extra
+        extra = ", ".join(extra)
+        return valid, f"Op has int32 tensor(s): {extra}"
 
     @classmethod
     @docstring_format_args(tens_dim_range)
@@ -268,35 +345,37 @@
         valid = True
         extra = []
         tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
-        tensors = tensors if tensors else op.inputs
+        if not tensors:
+            tensors = [tens for tens in op.inputs if tens]
         for tens in tensors:
             if not all(tens_min <= dim <= tens_max for dim in tens.shape):
                 valid = False
-                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
         return valid, ", ".join(extra)
 
     @staticmethod
     def constraint_tens_quant_none_check(op):
-        "Tensors must have quantization parameters"
+        "Input(s), Output and Weight tensors must have quantization parameters"
         valid = True
         extra = []
         tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
         for tens in tensors:
             if tens.quantization is None:
                 valid = False
-                extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
-        return valid, ", ".join(extra)
+                extra.append(tens.name)
+        extra = ", ".join(extra)
+        return valid, f"Op has tensors with missing quantization parameters: {extra}"
 
     @staticmethod
     def constraint_tens_quant_scale(op):
-        "Tensors with quantization scales must be finite"
+        "Input(s), Output and Weight tensors with quantization scales must be finite"
         valid = True
         extra = []
         tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
         for tens in tensors:
             if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
                 valid = False
-                extra.append("Tensor '{}' has quantization scale: {}".format(tens.name, tens.quantization.scale_f32))
+                extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
         return valid, ", ".join(extra)
 
     @classmethod
@@ -305,87 +384,71 @@
         "The fused activation function (if present) must be one of type: {}"
         faf = op.activation
         valid = (faf is None) or (faf in cls.supported_fused_activations)
-        extra = "Op has its fused activation function as: {}".format(faf)
-        return valid, extra
+        return valid, f"Op has its fused activation function as: {faf}"
 
     @staticmethod
     def constraint_stride_type(op):
         "Stride values for both width and height must be integer types"
-        w = op.attrs["stride_w"]
-        h = op.attrs["stride_h"]
+        w, h = op.get_kernel_stride()
         valid = is_integer(w) and is_integer(h)
-        extra = "Op has stride WxH as: {}x{}".format(repr(w), repr(h))
-        return valid, extra
+        return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"
 
     @classmethod
     @docstring_format_args(stride_range)
     def constraint_stride_range(cls, op):
         "Stride values for both width and height must be in the range [{}, {}]"
-        w = op.attrs["stride_w"]
-        h = op.attrs["stride_h"]
+        w, h = op.get_kernel_stride()
         stride_min, stride_max = cls.stride_range
         valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
-        extra = "Op has stride WxH as: {}x{}".format(w, h)
-        return valid, extra
+        return valid, f"Op has stride WxH as: {w}x{h}"
 
     @staticmethod
     def constraint_dilation_type(op):
         "Dilation factor values for both width and height must be integer types"
-        w = op.attrs.get("dilation_w_factor", 1)
-        h = op.attrs.get("dilation_h_factor", 1)
+        w, h = op.get_kernel_dilation()
         valid = is_integer(w) and is_integer(h)
-        extra = "Op has dilation factor WxH as: {}x{}".format(repr(w), repr(h))
-        return valid, extra
+        return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"
 
     @classmethod
     @docstring_format_args(dilation_range)
     def constraint_dilation_range(cls, op):
         "Dilation factor values for both width and height must be in the range [{}, {}]"
-        w = op.attrs.get("dilation_w_factor", 1)
-        h = op.attrs.get("dilation_h_factor", 1)
+        w, h = op.get_kernel_dilation()
         dilation_min, dilation_max = cls.dilation_range
         valid = (dilation_min <= w <= dilation_max) and (dilation_min <= h <= dilation_max)
-        extra = "Op has dilation factor WxH as: {}x{}".format(w, h)
-        return valid, extra
+        return valid, f"Op has dilation factor WxH as: {w}x{h}"
 
     @classmethod
     @docstring_format_args(dilated_height_range)
     def constraint_dilated_height_range(cls, op):
         "Dilated kernel height must be in the range [{}, {}]"
-        h = (op.weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
+        h = op.kernel.area_height()
         dilated_height_min, dilated_height_max = cls.dilated_height_range
         valid = dilated_height_min <= h <= dilated_height_max
-        extra = "Op has dilated kernel height as: {}".format(h)
-        return valid, extra
+        return valid, f"Op has dilated kernel height as: {h}"
 
     @classmethod
     @docstring_format_args(dilated_product_range)
     def constraint_dilated_product_range(cls, op):
         "Product of dilated kernel width and height must be in the range [{}, {}]"
-        weights = op.weights
-        w = (weights.shape[1] - 1) * op.attrs.get("dilation_w_factor", 1) + 1
-        h = (weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
-        product = w * h
+        product = op.kernel.area_width() * op.kernel.area_height()
         dilated_product_min, dilated_product_max = cls.dilated_product_range
         valid = dilated_product_min <= product <= dilated_product_max
-        extra = "Op has product of dilated kernel width and height as: {}".format(product)
-        return valid, extra
+        return valid, f"Op has product of dilated kernel width and height as: {product}"
 
     @staticmethod
     def constraint_weights_type(op):
-        "Weight Tensor must be 8-bit"
+        "Weight tensor must be 8-bit"
         weights = op.weights
         valid = weights.element_size() == 1
-        extra = "Tensor '{}' is {}-bit".format(weights.name, int(weights.element_size() * 8))
-        return valid, extra
+        return valid, f"Tensor '{weights.name}' is {int(weights.element_size() * 8)}-bit"
 
     @staticmethod
-    def constraint_weights_nonconst(op):
-        "Weight tensor cannot be non-constant"
+    def constraint_weights_const(op):
+        "Weight tensor must be constant"
         weights = op.weights
         valid = weights.values is not None
-        extra = "Tensor '{}' has non-constant values".format(weights.name)
-        return valid, extra
+        return valid, f"Tensor '{weights.name}' has non-constant values"
 
     @classmethod
     @docstring_format_args([weights_limit])
@@ -395,405 +458,409 @@
         values = weights.quant_values.astype(np.int64) - weights.quantization.zero_point
         limit = np.amax(np.sum(np.absolute(values), axis=(0, 1, 2)))
         valid = limit <= cls.weights_limit
-        extra = "Tensor '{}' has the sum of weights: {}".format(weights.name, limit)
-        return valid, extra
+        return valid, f"Tensor '{weights.name}' has the sum of weights: {limit}"
 
     @classmethod
     @docstring_format_args([supported_bias_dtypes])
     def constraint_bias_type(cls, op):
-        "Optional Bias Tensor must be of type: {}"
-        valid = True
-        extra = ""
+        "Optional Bias tensor must be of type: {}"
         bias = op.bias
         if bias:
             valid = bias.dtype in cls.supported_bias_dtypes
-            extra = "Tensor '{}' has data type: {}".format(bias.name, bias.dtype)
-        return valid, extra
+            return valid, f"Tensor '{bias.name}' has data type: {bias.dtype}"
+        return True, "Op has no bias tensor"
 
     @staticmethod
     def constraint_bias_40bit(op):
-        "Optional Bias Tensor values must fit within 40-bits"
-        valid = True
-        extra = ""
+        "Optional Bias tensor values must fit within 40-bits"
         bias = op.bias
         if bias and bias.dtype == DataType.int64:
             valid = all(len(bin(quant_value)[2:]) <= 40 for quant_value in bias.quant_values)
-            extra = "Tensor '{}' has values larger than 40-bits".format(bias.name)
-        return valid, extra
+            return valid, f"Tensor '{bias.name}' has values larger than 40-bits"
+        return True, "Op has no bias tensor, or it fits in 40-bit"
 
     @staticmethod
     def constraint_batch_size(op):
         "IFM Tensor batch size must be 1"
         ifm = op.ifm
         valid = ifm.shape[0] == 1
-        extra = "Tensor '{}' has batch size: {}".format(ifm.name, ifm.shape[0])
+        return valid, f"Tensor '{ifm.name}' has batch size: {ifm.shape[0]}"
+
+    @staticmethod
+    def constraint_quant_scale_inf(op):
+        "The IFM quantization scale divided by the OFM quantization scale must not be infinite"
+        ifm_scale = op.ifm.quantization.scale_f32
+        ofm_scale = op.ofm.quantization.scale_f32
+        valid = not np.isinf(ifm_scale / ofm_scale)
+        return valid, f"Op has infinite quantization scale. ifm_scale={ifm_scale} ofm_scale={ofm_scale}"
+
+    @staticmethod
+    def constraint_depth_multiplier(op):
+        "For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier"
+        depth_multiplier = op.attrs.get("depth_multiplier", 1)
+        if depth_multiplier > 1:
+            ifm_channels = op.ifm.shape[3]
+            ofm_channels = op.ofm.shape[3]
+            valid = (ifm_channels == 1) and (ofm_channels == depth_multiplier)
+            extra = (
+                f"Op has ifm_channels={ifm_channels}, ofm_channels={ofm_channels}"
+                f" and depth_multiplier={depth_multiplier}"
+            )
+            return valid, extra
+        return True, "Op has depth_multiplier=1"
+
+    @staticmethod
+    def constraint_tconv_stride(op):
+        "Stride values for both width and height must be 2"
+        w = op.kernel.stride.x
+        h = op.kernel.stride.y
+        valid = (w == 2) and (h == 2)
+        return valid, f"Op has stride WxH as: {w}x{h}"
+
+    @staticmethod
+    def constraint_tconv_same(op):
+        "SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride"
+        if op.attrs["padding"] == b"SAME":
+            w = op.kernel.stride.x
+            h = op.kernel.stride.y
+            ifm_shape = op.ifm.shape
+            ofm_shape = op.ofm.shape
+            valid = (ofm_shape[1] == (ifm_shape[1] * h)) and (ofm_shape[2] == (ifm_shape[2] * w))
+            return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and stride WxH as {w}x{h}"
+        return True, "Op has padding=VALID"
+
+    @staticmethod
+    def constraint_tconv_valid(op):
+        """VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
+                  minus difference between kernel size and stride"""
+        if op.attrs["padding"] == b"VALID":
+            s_w = op.kernel.stride.x
+            s_h = op.kernel.stride.y
+            k_w = op.kernel.width
+            k_h = op.kernel.height
+            ifm_shape = op.ifm.shape
+            ofm_shape = op.ofm.shape
+            height_check = ofm_shape[1] == (ifm_shape[1] * s_h + max(k_h - s_h, 0))
+            width_check = ofm_shape[2] == (ifm_shape[2] * s_w + max(k_w - s_w, 0))
+            valid = height_check and width_check
+            extra = (
+                f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape},"
+                f" stride WxH as {s_w}x{s_h} and kernel WxH as {k_w}x{k_h}"
+            )
+            return valid, extra
+        return True, "Op has padding=SAME"
+
+    @staticmethod
+    def constraint_matching_in_out_types(op):
+        "IFM and OFM data types must match"
+        ifm_dtype = op.ifm.dtype
+        ofm_dtype = op.ofm.dtype
+        valid = ifm_dtype == ofm_dtype
+        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
+
+    @staticmethod
+    def constraint_filter_type(op):
+        "Kernel filter values for both width and height must be integer types"
+        w = op.kernel.width
+        h = op.kernel.height
+        valid = is_integer(w) and is_integer(h)
+        return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"
+
+    @classmethod
+    @docstring_format_args(filter_range)
+    def constraint_filter_range(cls, op):
+        "Kernel filter values for both width and height must be in the range [{}, {}]"
+        if op.attrs["padding"] == b"SAME":
+            w = op.kernel.width
+            h = op.kernel.height
+            filter_min, filter_max = cls.filter_range
+            valid = (filter_min <= w <= filter_max) and (filter_min <= h <= filter_max)
+            return valid, f"Op has kernel filter WxH as: {w}x{h}"
+        return True, "Op has padding=VALID"
+
+    @classmethod
+    @docstring_format_args(filter_height_range)
+    def constraint_filter_height_range(cls, op):
+        "Kernel filter height must be in the range [{}, {}]"
+        h = op.kernel.height
+        filter_height_min, filter_height_max = cls.filter_height_range
+        valid = filter_height_min <= h <= filter_height_max
+        return valid, f"Op has kernel filter height as: {h}"
+
+    @classmethod
+    @docstring_format_args(filter_product_range)
+    def constraint_filter_product_range(cls, op):
+        "Product of kernel filter width and height must be in the range [{}, {}]"
+        product = op.kernel.elements_wh()
+        filter_product_min, filter_product_max = cls.filter_product_range
+        valid = filter_product_min <= product <= filter_product_max
+        return valid, f"Op has product of kernel filter width and height as: {product}"
+
+    @staticmethod
+    @docstring_format_args(filter_height_range)
+    def constraint_filter_height_range_valid_pad(op):
+        "VALID padding: Kernel filter height must be in the range [{}, {}]"
+        if op.attrs["padding"] == b"VALID":
+            return SupportedOperators.constraint_filter_height_range(op)
+        return True, "Op has padding=SAME"
+
+    @staticmethod
+    @docstring_format_args(filter_product_range)
+    def constraint_filter_product_range_valid_pad(op):
+        "VALID padding: Product of kernel filter width and height must be in the range [{}, {}]"
+        if op.attrs["padding"] == b"VALID":
+            return SupportedOperators.constraint_filter_product_range(op)
+        return True, "Op has padding=SAME"
+
+    @staticmethod
+    def constraint_resize(op):
+        """The width and height of the IFM and OFM must match one of the following criteria:
+        IFM W and H must both be 1
+        IFM must match OFM
+        OFM W and H must be 2x IFM -1, if align_corners is True
+        OFM W and H must be 2x IFM, if align_corners is False"""
+        # Easier to start with False condition as very few cases result in a supported resize
+        valid = False
+        ifm_shape = op.ifm.shape
+        ofm_shape = op.ofm.shape
+        align_corners = op.attrs.get("align_corners", False)
+        if len(ifm_shape) == 4:
+            # Valid if IFM W and H are both 1, or IFM and OFM shape are the same
+            if ((ifm_shape[1] == 1) and (ifm_shape[2] == 1)) or (ifm_shape == ofm_shape):
+                valid = True
+            else:
+                upscaled_shape = np.array(ifm_shape[1:3])
+                out_shape = np.array(ofm_shape[1:3])
+                while (upscaled_shape < out_shape).all():
+                    upscaled_shape *= 2
+                    if align_corners:
+                        upscaled_shape -= 1
+                    # Valid if OFM is 2x IFM (-1 for align corners)
+                    if np.array_equal(out_shape, upscaled_shape):
+                        valid = True
+                        break
+        return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and align_corners={align_corners}"
+
+    @staticmethod
+    def constraint_matching_shapes(op):
+        "IFM and OFM shapes must match"
+        ifm_shape = op.ifm.shape
+        ofm_shape = op.ofm.shape
+        valid = ifm_shape == ofm_shape
+        return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"
+
+    @staticmethod
+    def constraint_splitv_inferred(op):
+        "Only one size is allowed to be inferred"
+        sizes = op.ifm2.values
+        valid = np.count_nonzero(sizes == -1) <= 1
+        return valid, f"Op has multiple inferred sizes (-1): {sizes}"
+
+    @staticmethod
+    def constraint_axis_exists(op):
+        "Axis attribute must exist"
+        axis = op.attrs.get("axis")
+        valid = axis is not None
+        return valid, f"Op has axis={axis}"
+
+    @staticmethod
+    def constraint_axis_valid(op):
+        "Axis attribute must be in the range [0, <ofm_dimensions>)"
+        dims = len(op.ofm.shape)
+        axis = op.attrs["axis"]
+        axis += dims if axis < 0 else 0
+        valid = 0 <= axis < dims
+        return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"
+
+    @staticmethod
+    def constraint_matching_dimensionality(op):
+        "All Input dimensionalities must match OFM dimensionality"
+        valid = True
+        extra = []
+        ofm_dim = len(op.ofm.shape)
+        tensors = [tens for tens in op.inputs if tens]
+        for tens in tensors:
+            dim = len(tens.shape)
+            if dim != ofm_dim:
+                valid = False
+                extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
+        extra = ", ".join(extra)
+        return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"
+
+    @staticmethod
+    def constraint_valid_dimensions(op):
+        "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
+        valid = True
+        extra = []
+        ofm_shape = op.ofm.shape
+        ofm_dim = len(ofm_shape)
+        axis = op.attrs["axis"]
+        axis += ofm_dim if axis < 0 else 0
+        tensors = [tens for tens in op.inputs if tens]
+        for tens in tensors:
+            if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
+                valid = False
+                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+        extra = ", ".join(extra)
+        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
+
+    @staticmethod
+    def constraint_stridedslice_input_count(op):
+        "Exactly 4 Input tensors are required"
+        inputs = len(op.inputs)
+        valid = inputs == 4
+        return valid, f"Op has {inputs} inputs"
+
+    @staticmethod
+    def constraint_stridedslice_inputs_const(op):
+        "Begin, End and Stride Input tensors must be constant"
+        valid = True
+        extra = []
+        _, begin, end, strides = op.inputs
+        if begin.values is None:
+            valid = False
+            extra.append(f"Begin tensor '{begin.name}'")
+        if end.values is None:
+            valid = False
+            extra.append(f"End tensor '{end.name}'")
+        if strides.values is None:
+            valid = False
+            extra.append(f"Stride tensor '{strides.name}'")
+        extra = ", ".join(extra)
+        return valid, f"Op has non-constant tensors: {extra}"
+
+    @staticmethod
+    def constraint_stridedslice_tens_size_matches(op):
+        "All Input sizes must match OFM size"
+        ifm, begin, end, strides = op.inputs
+        ifm_size = len(ifm.shape)
+        ofm_size = len(op.ofm.shape)
+        begin_size = len(begin.values)
+        end_size = len(end.values)
+        strides_size = len(strides.values)
+        valid = ifm_size == ofm_size == begin_size == end_size == strides_size
+        extra = (
+            f"Op has ofm_size={ofm_size}, ifm_size={ifm_size},"
+            f" begin_size={begin_size}, end_size={end_size} and strides_size={strides_size}"
+        )
         return valid, extra
 
-    @classmethod
-    def check_depthwise_convolution_restrictions(cls, op):
-        # check depth
-        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
-        if op.attrs["depth_multiplier"] > 1 and not (
-            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
-        ):
-            print(
-                "Warning: for depth multipliers > 1,",
-                "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
-                "Placing on CPU",
-            )
-            return False
-        return True
+    @staticmethod
+    def constraint_stridedslice_stride_values(op):
+        "All Strides values must be 1"
+        strides = op.inputs[3]
+        valid = all(stride == 1 for stride in strides.values)
+        return valid, f"Op has strides values {strides.values}"
 
-    @classmethod
-    def check_transpose_convolution_restrictions(cls, op):
-        # check stride
-        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
-        if stride_h != 2 or stride_w != 2:
-            print("Warning: stride must be equal to 2, placing on CPU")
-            return False
+    @staticmethod
+    def constraint_ellipsis_mask(op):
+        "ellipsis_mask must be 0"
+        ellipsis = op.attrs["ellipsis_mask"]
+        valid = ellipsis == 0
+        return valid, f"Op has ellipsis mask as: {ellipsis}"
 
-        # check output dimensions
-        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
-        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
-        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
-        if op.attrs["padding"] == b"SAME":
-            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
-                print(
-                    "Warning: for",
-                    op.type,
-                    "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
-                    "Placing on CPU",
-                )
-                return False
-        elif op.attrs["padding"] == b"VALID":
-            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
-            if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
-                ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
-            ):
-                print(
-                    "Warning: for",
-                    op.type,
-                    "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
-                    "minus difference between kernel size and stride. Placing on CPU",
-                )
-                return False
-        return True
+    @staticmethod
+    def constraint_axis_masks(op):
+        "new_axis_mask and shrink_axis_mask cannot both be set"
+        new_axis = op.attrs["new_axis_mask"]
+        shrink_axis = op.attrs["shrink_axis_mask"]
+        valid = (new_axis == 0) or (shrink_axis == 0)
+        return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"
 
-    @classmethod
-    def check_pooling_restrictions(cls, op):
-        # check stride
-        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
-        if not is_integer(stride_w) or not is_integer(stride_h):
-            print("Warning:", op.type, "has non-integer stride, placing on CPU")
-            return False
-        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
-            print(
-                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
-                    op.type, stride_w, stride_h
-                )
-            )
-            return False
+    @staticmethod
+    def constraint_slice_ranges(op):
+        "Slice 'end' values must be greater than 'begin' values"
+        ifm, begin, end, _ = op.inputs
+        # Calculate offset begin/end
+        offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
+        offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
+        # Check "end - begin" doesn't result in any zero or negative elements
+        valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
+        return valid, f"Op has begin_values={begin.values} and end_values={end.values}"
 
-        # check data type
-        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
-        if ifm_tensor.dtype != ofm_tensor.dtype:
-            if op.type != Op.ReduceSum:
-                print("Warning: input data type doesn't match output data type, placing on CPU")
-                return False
-            # TODO: else check ReduceSum restrictions.
+    @staticmethod
+    def constraint_matching_inputs_types(op):
+        "Both Input data types must match"
+        ifm_dtype = op.ifm.dtype
+        ifm2_dtype = op.ifm2.dtype
+        valid = ifm_dtype == ifm2_dtype
+        return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"
 
-        # check batch size
-        if ifm_tensor.shape[0] != 1:
-            print("Warning: input batch size must be 1, placing on CPU")
-            return False
+    @staticmethod
+    def constraint_matching_signed(op):
+        "For IFM that are signed, OFM must also be signed"
+        valid = True
+        ifm_dtype = op.ifm.dtype
+        ofm_dtype = op.ofm.dtype
+        if ifm_dtype.type & BaseType.Signed:
+            valid = bool(ofm_dtype.type & BaseType.Signed)
+        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
 
-        # check kernel size
-        kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
-        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
-            if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
-                print(
-                    "Warning:",
-                    op.type,
-                    "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
-                        kernel_w, kernel_h
-                    ),
-                )
-                return False
-        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops:
-            if not 1 <= kernel_w * kernel_h <= 256 * 256:
-                print(
-                    "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
-                        256 * 256
-                    ),
-                    "placing on CPU",
-                )
-                return False
-            if not 1 <= kernel_h <= 256:
-                print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
-                return False
+    @staticmethod
+    def constraint_unsigned_valid(op):
+        "For IFM that are unsigned, OFM must either be the same type or int32"
+        valid = True
+        ifm_dtype = op.ifm.dtype
+        ofm_dtype = op.ofm.dtype
+        if ifm_dtype.type & BaseType.Unsigned:
+            valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
+        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
 
-        return True
+    @staticmethod
+    def constraint_inputs_int32(op):
+        "Both Input data types must be int32"
+        ifm_dtype = op.ifm.dtype
+        ifm2_dtype = op.ifm2.dtype
+        valid = (ifm_dtype == DataType.int32) and (ifm2_dtype == DataType.int32)
+        return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"
 
-    @classmethod
-    def check_resize_restrictions(cls, op):
-        # check unsupported upscaling factor
-        if op.type == Op.ResizeBilinear:
-            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
-                return True
-            if op.inputs[0].shape == op.outputs[0].shape:
-                return True
-            upscaled_shape = np.array(op.inputs[0].shape[1:3])
-            out_shape = np.array(op.outputs[0].shape[1:3])
-            while (upscaled_shape < out_shape).all():
-                upscaled_shape *= 2
-                if op.attrs["align_corners"]:
-                    upscaled_shape -= 1
-                if np.array_equal(out_shape, upscaled_shape):
-                    return True
-        return False
+    @staticmethod
+    def constraint_output_int32(op):
+        "OFM must be int32"
+        ofm_dtype = op.ofm.dtype
+        valid = ofm_dtype == DataType.int32
+        return valid, f"Op has ofm_dtype={ofm_dtype}"
 
-    @classmethod
-    def check_vector_product_restrictions(cls, op):
-        # check data type
-        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
-        if weight_tensor.element_size() > 1:
-            print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
-            return False
+    @staticmethod
+    def constraint_matching_quantization_parameters(op):
+        "Both Input quantization parameters must match OFM quantization parameters"
+        valid = True
+        extra = []
+        if not check_quantized_tens_scaling_equal(op.ofm, op.ifm):
+            valid = False
+            extra.append(op.ifm.name)
+        if not check_quantized_tens_scaling_equal(op.ofm, op.ifm2):
+            valid = False
+            extra.append(op.ifm2.name)
+        extra = ", ".join(extra)
+        return valid, f"Op has tensors with different quantization parameters to the OFM '{op.ofm.name}': {extra}"
 
-        if not cls.check_bias_restrictions(bias_tensor):
-            return False
+    @staticmethod
+    def constraint_elemwise_batch_size(op):
+        "Batch size must be 1 for Input tensors with more than 2 dimensions"
+        valid = True
+        extra = []
+        for tens in (op.ifm, op.ifm2):
+            # Unary ops have ifm2 as None
+            if tens is not None:
+                if (len(tens.shape) > 2) and (tens.shape[0] != 1):
+                    valid = False
+                    extra.append(tens.name)
+        extra = ", ".join(extra)
+        return valid, f"Op has invalid input tensors: {extra}"
 
-        # check non const weights
-        if weight_tensor.values is None:
-            print("Warning:", op.type, "has non-const weights, placing on CPU")
-            return False
+    @staticmethod
+    def constraint_matching_either_shapes(op):
+        "At least one Input's shape must match the OFM's shape"
+        ifm_shape = op.ifm.shape
+        ifm2_shape = op.ifm2.shape if op.ifm2 else None
+        ofm_shape = op.ofm.shape
+        valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
+        return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"
 
-        return True
-
-    @classmethod
-    def check_element_wise_restrictions(cls, op):
-        # check data type
-        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
-        # input and output datatype must match for these operators
-        if (
-            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
-            and ifm_tensor.dtype != ofm_tensor.dtype
-        ):
-            print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
-            return False
-        if op.type in cls.binary_elem_wise_add_mul_sub:
-            # both inputs must have same type
-            if ifm_tensor.dtype != ifm2_tensor.dtype:
-                print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
-                return False
-            # signed input check
-            if ifm_tensor.dtype.type & BaseType.Signed:
-                # output must be signed
-                if ofm_tensor.dtype.type & BaseType.Unsigned:
-                    print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
-                    return False
-                # and 8, 16 or 32-bit
-                bit_lengths = {8, 16, 32}
-                if ofm_tensor.element_size() * 8 not in bit_lengths:
-                    print(
-                        "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
-                    )
-                    return False
-            # unsigned input check, output must be same type or int32
-            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
-                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
-            ):
-                print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
-                return False
-        elif op.type in cls.binary_elem_wise_shift_ops:
-            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
-                print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
-                return False
-            if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
-                print("Warning:", op.type, "output datatype is not int32, placing on CPU")
-                return False
-
-        # check batch size
-        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
-            print(
-                "Warning:",
-                op.type,
-                "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
-            )
-            return False
-        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
-            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
-                print(
-                    "Warning:",
-                    op.type,
-                    "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
-                )
-                return False
-
-        # negative alpha values are not supported
-        if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
-            print("Warning:", op.type, "has negative alpha, placing on CPU")
-            return False
-
-        # check if ifm or ifm2 has ofm shape
-        if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
-            print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
-            return False
-
-        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
-            return False
-
-        return True
-
-    @classmethod
-    def check_memory_only_restrictions(cls, op):
-        if op.type == Op.StridedSlice:
-            if len(op.inputs) != 4:
-                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
-                return False
-            input_tens, begin_tens, end_tens, strides_tens = op.inputs
-            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
-                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
-                return False
-            if not (
-                len(input_tens.shape)
-                == len(op.outputs[0].shape)
-                == len(begin_tens.values)
-                == len(end_tens.values)
-                == len(strides_tens.values)
-            ):
-                warn_cpu(op, "has input tensors with shapes that are not supported")
-                return False
-            # check stride size
-            if any(stride != 1 for stride in strides_tens.values):
-                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
-                return False
-            # check ellipsis_mask
-            if op.attrs["ellipsis_mask"] != 0:
-                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
-                return False
-            # check if both new_axis_mask and shrink_axis_mask have bit set
-            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
-                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
-                return False
-            # Calculate offset start/end
-            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
-            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
-            # check "end - begin" doesn't result in any zero or negative elements
-            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
-                warn_cpu(
-                    op,
-                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
-                        begin_tens.values, end_tens.values
-                    ),
-                )
-                return False
-        if op.type == Op.SplitV:
-            # check that maximum one size is set to -1, indicating that size should be inferred
-            sizes = op.inputs[1].values
-            num_to_be_inferred = 0
-            for size in sizes:
-                if size == -1:
-                    num_to_be_inferred += 1
-
-            if num_to_be_inferred > 1:
-                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
-                return False
-        if op.type in set((Op.Concat, Op.ConcatTFLite,)):
-            axis = op.attrs.get("axis", None)
-            if axis is None:
-                print("Warning:", op.type, "invalid or missing axis, placing on CPU")
-                return False
-            if axis < 0:
-                axis += len(op.inputs[0].shape)
-            if not 0 <= axis < len(op.inputs[0].shape):
-                print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
-                return False
-            ofm = op.outputs[0]
-            ofm_dims = len(ofm.shape)
-            for ifm in op.inputs:
-                if len(ifm.shape) != ofm_dims:
-                    return False
-                for i in range(ofm_dims):
-                    if i != axis and ifm.shape[i] != ofm.shape[i]:
-                        print(
-                            "Warning:",
-                            op.type,
-                            "invalid ifm:",
-                            ifm.name,
-                            ifm.shape,
-                            "mismatch in dimension",
-                            i,
-                            ", placing on CPU",
-                        )
-                        return False
-
-        return True
-
-    @classmethod
-    def check_quantization_restrictions_binary_elem_wise(cls, op):
-        # checks that IFM1, IFM2 and OFM quantization are equal for binary ops
-
-        assert len(op.inputs) >= 2 and len(op.outputs) == 1
-
-        if (
-            not check_tens_quantized(op.inputs[0])
-            or not check_tens_quantized(op.inputs[1])
-            or not check_tens_quantized(op.outputs[0])
-        ):
-            warn_cpu(op, "has non-quantised input and/or output tensors")
-            return False
-
-        if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal(
-            op.inputs[0], op.outputs[0]
-        ):
-            warn_cpu(op, "has input/output tensors with different quantisation which is illegal")
-            return False
-
-        return True
-
-    @classmethod
-    def check_activation_ops(cls, op):
-        if op.type == Op.Softmax:
-            ifm_tensor = op.inputs[0]
-            ofm_tensor = op.outputs[0]
-
-            # check data type
-            if ifm_tensor.dtype != ofm_tensor.dtype:
-                print("Warning:", op.type, "input type differs from output type, placing on CPU")
-                return False
-
-            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
-                print(
-                    "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
-                )
-                return False
-
-            # check shape
-            if ifm_tensor.shape != ofm_tensor.shape:
-                print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
-                return False
-
-        elif op.type.is_relu_op():
-            ifm_tensor, ofm_tensor = op.get_ifm_ofm()
-            if np.isinf(ifm_tensor.quantization.scale_f32 / ofm_tensor.quantization.scale_f32):
-                print("Warning:", op.type, "has an infinite scale value, placing on CPU")
-                return False
-
-        return True
-
-    @classmethod
-    def check_bias_restrictions(cls, bias_tensor):
-        # check data type
-        if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
-            print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
-            return False
-
-        # check if values fits in 40-bit
-        if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
-            for quant_value in bias_tensor.quant_values:
-                if not (-(1 << 39) <= quant_value < (1 << 39)):
-                    print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
-                    return False
-
-        return True
+    @staticmethod
+    def constraint_alpha_valid(op):
+        "Alpha must not be negative"
+        alpha = op.attrs["alpha"]
+        valid = alpha >= 0
+        return valid, f"Op has alpha={alpha}"