vela: Promote SupportedOperators operator sets to class attributes

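Part of a larger refactoring. The sets of operators do not need to be
instance attributes, as they are not expected to be modified at
runtime. Making them class attributes in turn allows almost all of the
check_* functions to become class methods.

This also fixes a typo in depthwise_convolution_ops, where the comma
was inside the string literal ("QuantizedDepthwiseConv2D,"). As a
result, QuantizedDepthwiseConv2D was previously not in
supported_operators and is now recognised as a supported operator.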

Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: I7dc24d65cdd6c4bda641b3d6133b3134302a552f
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 73a4f28..867613c 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -27,104 +27,103 @@
 
 
 class SupportedOperators:
+    # Categorised lists of supported operators
+    npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
+    convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D",))
+    depthwise_convolution_ops = set(("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D",))
+    transpose_convolution_ops = set(("Conv2DBackpropInput",))
+    max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct",))
+    avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct",))
+    pooling_ops = set(("ReduceSum",)) | max_pooling_ops | avg_pooling_ops
+    resizing_ops = set(("ResizeBilinear",))
+    fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct",))
+    mac_main_ops = (
+        # RNN/LSTM/GRU
+        set(("BlockLSTM",))
+        # convolutions
+        | convolution_ops
+        # depth-wise convolutions
+        | depthwise_convolution_ops
+        # transpose convolutions
+        | transpose_convolution_ops
+        # pooling
+        | pooling_ops
+        # resizing/upscaling
+        | resizing_ops
+        # FC layers
+        | fc_vector_products
+    )
+    unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
+    binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
+    binary_elem_wise_shift_ops = set(("SHL", "SHR",))
+    binary_elem_wise_add_mul_sub = set(
+        ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
+    )
+    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
+    activation_ops = set(
+        (
+            "QuantizedRelu",
+            "QuantizedRelu1",
+            "QuantizedRelu6",
+            "Relu",
+            "Relu6",
+            "ReluN1To1",
+            "Sigmoid",
+            "Tanh",
+            "Softmax",
+        )
+    )
+    npu_post_ops = (
+        # concatenation write direction
+        set(("ConcatSliceWrite",))
+        # bias add and batch norm
+        | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm",))
+        # Quantization
+        | set(("Quantize",))
+        # activation functions
+        | activation_ops
+    )
+    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack",))
+    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack",))
+    memory_only_ops = set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",)) | concat_ops | split_ops
+    shapeless_input_ops = set(("Split", "SplitV",)) | binary_elem_wise_main_ops
+    supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
+    supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
+
     def __init__(self):
-        # Categorised lists of supported operators
-        self.npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
-        self.convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D",))
-        self.depthwise_convolution_ops = set(
-            ("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D,")
-        )
-        self.transpose_convolution_ops = set(("Conv2DBackpropInput",))
-        self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct",))
-        self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct",))
-        self.pooling_ops = set(("ReduceSum",)) | self.max_pooling_ops | self.avg_pooling_ops
-        self.resizing_ops = set(("ResizeBilinear",))
-        self.fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct",))
-        self.mac_main_ops = (
-            # convolutions
-            self.convolution_ops
-            # depth-wise convolutions
-            | self.depthwise_convolution_ops
-            # transpose convolutions
-            | self.transpose_convolution_ops
-            # pooling
-            | self.pooling_ops
-            # resizing/upscaling
-            | self.resizing_ops
-            # FC layers
-            | self.fc_vector_products
-            # RNN/LSTM/GRU
-            | set(("BlockLSTM",))
-        )
-        self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
-        self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
-        self.binary_elem_wise_shift_ops = set(("SHL", "SHR",))
-        self.binary_elem_wise_add_mul_sub = set(
-            ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
-        )
-        self.binary_elem_wise_main_ops = (
-            self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
-        )
-        self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
-        self.activation_ops = set(
-            (
-                "QuantizedRelu",
-                "QuantizedRelu1",
-                "QuantizedRelu6",
-                "Relu",
-                "Relu6",
-                "ReluN1To1",
-                "Sigmoid",
-                "Tanh",
-                "Softmax",
-            )
-        )
-        self.npu_post_ops = (
-            # activation functions
-            self.activation_ops
-            # concatenation write direction
-            | set(("ConcatSliceWrite",))
-            # bias add and batch norm
-            | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm",))
-            # Quantization
-            | set(("Quantize",))
-        )
-        self.split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack",))
-        self.concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack",))
-        self.memory_only_ops = (
-            set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",)) | self.concat_ops | self.split_ops
-        )
-        self.shapeless_input_ops = self.binary_elem_wise_main_ops | set(("Split", "SplitV",))
-        self.supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
-        self.supported_operators = (
-            self.npu_pre_ops | self.mac_main_ops | self.elem_wise_main_ops | self.npu_post_ops | self.memory_only_ops
-        )
         # Setup supported operator restriction checkers
         self.supported_operator_restrictions = {}
         self.supported_operator_restrictions.update(
-            {op: self.check_convolution_restrictions for op in self.convolution_ops}
+            {op: self.check_convolution_restrictions for op in SupportedOperators.convolution_ops}
         )
         self.supported_operator_restrictions.update(
-            {op: self.check_depthwise_convolution_restrictions for op in self.depthwise_convolution_ops}
+            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
         )
         self.supported_operator_restrictions.update(
-            {op: self.check_transpose_convolution_restrictions for op in self.transpose_convolution_ops}
-        )
-        self.supported_operator_restrictions.update({op: self.check_pooling_restrictions for op in self.pooling_ops})
-        self.supported_operator_restrictions.update({op: self.check_resize_restrictions for op in self.resizing_ops})
-        self.supported_operator_restrictions.update(
-            {op: self.check_vector_product_restrictions for op in self.fc_vector_products}
+            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
         )
         self.supported_operator_restrictions.update(
-            {op: self.check_element_wise_restrictions for op in self.elem_wise_main_ops}
+            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
         )
         self.supported_operator_restrictions.update(
-            {op: self.check_memory_only_restrictions for op in self.memory_only_ops}
+            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
         )
-        self.supported_operator_restrictions.update({op: self.check_activation_ops for op in self.activation_ops})
+        self.supported_operator_restrictions.update(
+            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
+        )
+        self.supported_operator_restrictions.update(
+            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
+        )
+        self.supported_operator_restrictions.update(
+            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
+        )
+        self.supported_operator_restrictions.update(
+            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
+        )
 
     def is_operator_supported(self, op):
-        if op.type not in self.supported_operators:
+        if op.type not in SupportedOperators.supported_operators:
             return False
         if not self.check_generic_restrictions(op):
             return False
@@ -132,7 +131,8 @@
             return self.supported_operator_restrictions[op.type](op)
         return True
 
-    def check_generic_restrictions(self, op):
+    @classmethod
+    def check_generic_restrictions(cls, op):
         # check fully defined shapes
         for t in op.inputs:
             if not t:
@@ -140,7 +140,7 @@
             if not t.has_fully_defined_shape():
                 print("Warning:", op.type, "has input(s) of undefined shape, placing on CPU")
                 return False
-            if t.shape == [] and op.type not in self.shapeless_input_ops:
+            if t.shape == [] and op.type not in cls.shapeless_input_ops:
                 print(
                     "Warning:",
                     op.type,
@@ -180,8 +180,8 @@
                 t.element_size() > 2
                 and op.type
                 not in set(("Requantize", "ReduceSum", "CLZ",))
-                | self.binary_elem_wise_add_mul_sub
-                | self.binary_elem_wise_shift_ops
+                | cls.binary_elem_wise_add_mul_sub
+                | cls.binary_elem_wise_shift_ops
             ):
                 return False
             # check size
@@ -192,7 +192,7 @@
         if (
             "fused_activation_function" in op.attrs
             and op.attrs["fused_activation_function"] is not None
-            and op.attrs["fused_activation_function"] not in self.supported_fused_activations
+            and op.attrs["fused_activation_function"] not in cls.supported_fused_activations
         ):
             return False
 
@@ -209,7 +209,8 @@
 
         return True
 
-    def check_convolution_restrictions(self, op):
+    @classmethod
+    def check_convolution_restrictions(cls, op):
         # check stride
         if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
             return False
@@ -225,7 +226,7 @@
         if weight_tensor.element_size() > 1:
             return False
 
-        if not self.check_bias_restrictions(bias_tensor):
+        if not cls.check_bias_restrictions(bias_tensor):
             return False
 
         # check kernel size [HWIO]
@@ -255,16 +256,18 @@
 
         return True
 
-    def check_depthwise_convolution_restrictions(self, op):
+    @classmethod
+    def check_depthwise_convolution_restrictions(cls, op):
         # check depth
         ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
         if op.attrs["depth_multiplier"] > 1 and not (
             (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
         ):
             return False
-        return self.check_convolution_restrictions(op)
+        return cls.check_convolution_restrictions(op)
 
-    def check_transpose_convolution_restrictions(self, op):
+    @classmethod
+    def check_transpose_convolution_restrictions(cls, op):
         # check stride
         stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
         if stride_h != stride_w != 2:
@@ -284,9 +287,10 @@
             ):
                 return False
 
-        return self.check_convolution_restrictions(op)
+        return cls.check_convolution_restrictions(op)
 
-    def check_pooling_restrictions(self, op):
+    @classmethod
+    def check_pooling_restrictions(cls, op):
         # check stride
         if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
             return False
@@ -302,7 +306,7 @@
         if ifm_tensor.shape[0] != 1:
             return False
 
-        if op.type in self.avg_pooling_ops:
+        if op.type in cls.avg_pooling_ops:
             # check kernel size
             if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8):
                 return False
@@ -311,13 +315,14 @@
             ):
                 return False
 
-        if op.type in self.max_pooling_ops:
+        if op.type in cls.max_pooling_ops:
             # check kernel size (any padding)
             if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256:
                 return False
         return True
 
-    def check_resize_restrictions(self, op):
+    @classmethod
+    def check_resize_restrictions(cls, op):
         # check unsupported upscaling factor
         if op.type == "ResizeBilinear":
             if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
@@ -334,13 +339,14 @@
                     return True
         return False
 
-    def check_vector_product_restrictions(self, op):
+    @classmethod
+    def check_vector_product_restrictions(cls, op):
         # check data type
         _, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
         if weight_tensor.element_size() > 1:
             return False
 
-        if not self.check_bias_restrictions(bias_tensor):
+        if not cls.check_bias_restrictions(bias_tensor):
             return False
 
         # check non const weights
@@ -350,16 +356,17 @@
 
         return True
 
-    def check_element_wise_restrictions(self, op):
+    @classmethod
+    def check_element_wise_restrictions(cls, op):
         # check data type
         ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
         # input and output datatype must match for these operators
         if (
-            op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops
+            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
             and ifm_tensor.dtype != ofm_tensor.dtype
         ):
             return False
-        if op.type in self.binary_elem_wise_add_mul_sub:
+        if op.type in cls.binary_elem_wise_add_mul_sub:
             # both inputs must have same type
             if ifm_tensor.dtype != ifm2_tensor.dtype:
                 return False
@@ -376,7 +383,7 @@
                 ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
             ):
                 return False
-        elif op.type in self.binary_elem_wise_shift_ops | set(("CLZ")):
+        elif op.type in cls.binary_elem_wise_shift_ops | set(("CLZ")):  # NOTE(review): set(("CLZ")) == {"C", "L", "Z"}, so CLZ never takes this branch; adding the tuple comma would also need a None check on ifm2_tensor (CLZ is unary) — confirm intent
             if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                 return False
             if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32:
@@ -385,7 +392,7 @@
         # check batch size
         if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
             return False
-        if op.type in self.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
+        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
             if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                 return False
 
@@ -397,14 +404,13 @@
         if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
             return False
 
-        if op.type in self.binary_elem_wise_min_max_ops and not self.check_quantization_restrictions_binary_elem_wise(
-            op
-        ):
+        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
             return False
 
         return True
 
-    def check_memory_only_restrictions(self, op):
+    @classmethod
+    def check_memory_only_restrictions(cls, op):
         if op.type == "StridedSlice":
             if len(op.inputs) != 4:
                 warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
@@ -488,7 +494,8 @@
 
         return True
 
-    def check_quantization_restrictions_binary_elem_wise(self, op):
+    @classmethod
+    def check_quantization_restrictions_binary_elem_wise(cls, op):
         # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
         assert len(op.inputs) >= 2 and len(op.outputs) == 1
 
@@ -504,7 +511,8 @@
 
         return True
 
-    def check_activation_ops(self, op):
+    @classmethod
+    def check_activation_ops(cls, op):
         if op.type == "Softmax":
             ifm_tensor = op.inputs[0]
             ofm_tensor = op.outputs[0]
@@ -522,7 +530,8 @@
 
         return True
 
-    def check_bias_restrictions(self, bias_tensor):
+    @classmethod
+    def check_bias_restrictions(cls, bias_tensor):
         # check data type
         if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
             return False