[MLBEDSW-2335] SoftMax int16

Added graph rewrite of Softmax for int16.

In supported_operators.py, Softmax is now accepted as an activation op behind a
new softmax_support constructor flag and validated by check_activation_ops:
int16 only, matching IFM/OFM data types, and batch size 1. The ReduceSum, CLZ,
SHL and SHR primitives and the LUT fused activation used by the rewrite are
also added to the supported operator sets.

Change-Id: Id7885af6056a23e8b8362fb61ae94283251eb398
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
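
As an illustration of how the new check gates Softmax, below is a minimal
sketch that mirrors the logic of check_activation_ops using stand-in types.
The Tensor/Op tuples and the plain string dtypes are hypothetical stand-ins
for illustration only; the real code operates on Vela tensor objects and
DataType.int16.

    from collections import namedtuple

    # Hypothetical stand-ins for Vela's tensor/operation objects.
    Tensor = namedtuple("Tensor", ["dtype", "shape"])
    Op = namedtuple("Op", ["type", "inputs", "outputs"])

    def softmax_is_supported(op, softmax_support=True):
        if op.type != "Softmax":
            return True              # other activation ops are unaffected
        if not softmax_support:
            return False             # feature flag disabled
        ifm, ofm = op.inputs[0], op.outputs[0]
        if ifm.dtype != ofm.dtype:   # IFM and OFM data types must match
            return False
        if ifm.dtype != "int16":     # only int16 is handled so far
            return False
        if len(ifm.shape) in (2, 4) and ifm.shape[0] != 1:
            return False             # batch size must be 1
        return True

    # A 1x10 int16 Softmax is accepted; an int8 one is rejected for now.
    print(softmax_is_supported(Op("Softmax", [Tensor("int16", (1, 10))], [Tensor("int16", (1, 10))])))  # True
    print(softmax_is_supported(Op("Softmax", [Tensor("int8", (1, 10))], [Tensor("int8", (1, 10))])))    # False
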
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 3ec3429..73e219b 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -20,19 +20,20 @@
 
 
 class SupportedOperators:
-    def __init__(self):
+    def __init__(self, softmax_support):
+        self.softmax_support = softmax_support
         # Categorised lists of supported operators
-        self.npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead"))
-        self.convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D"))
+        self.npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
+        self.convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D",))
         self.depthwise_convolution_ops = set(
-            ("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D")
+            ("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D,")
         )
         self.transpose_convolution_ops = set(("Conv2DBackpropInput",))
-        self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct"))
-        self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct"))
-        self.pooling_ops = self.max_pooling_ops | self.avg_pooling_ops
+        self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct",))
+        self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct",))
+        self.pooling_ops = set(("ReduceSum",)) | self.max_pooling_ops | self.avg_pooling_ops
         self.resizing_ops = set(("ResizeBilinear",))
-        self.fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct"))
+        self.fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct",))
         self.mac_main_ops = (
             # convolutions
             self.convolution_ops
@@ -47,34 +48,56 @@
             # FC layers
             | self.fc_vector_products
             # RNN/LSTM/GRU
-            | set(("BlockLSTM"))
+            | set(("BlockLSTM",))
         )
-        self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs"))
-        self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum"))
+        self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
+        self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
         self.binary_elem_wise_add_mul_sub = set(
-            ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
+            (
+                "AddAct",
+                "MulAct",
+                "SubAct",
+                "QuantizedAdd",
+                "QuantizedSub",
+                "QuantizedMul",
+                "Mul",
+                "Add",
+                "Sub",
+                "SHL",
+                "SHR",
+            )
         )
         self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
         self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
         self.activation_ops = set(
-            ("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh")
+            (
+                "QuantizedRelu",
+                "QuantizedRelu1",
+                "QuantizedRelu6",
+                "Relu",
+                "Relu6",
+                "ReluN1To1",
+                "Sigmoid",
+                "Tanh",
+                "Softmax",
+            )
         )
         self.npu_post_ops = (
             # activation functions
             self.activation_ops
             # concatenation write direction
-            | set(("ConcatSliceWrite"))
+            | set(("ConcatSliceWrite",))
             # bias add and batch norm
-            | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm"))
+            | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm",))
             # Quantization
             | set(("Quantize",))
         )
-        self.split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack"))
-        self.concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack"))
+        self.split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack",))
+        self.concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack",))
         self.memory_only_ops = (
-            set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")) | self.concat_ops | self.split_ops
+            set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",)) | self.concat_ops | self.split_ops
         )
-        self.supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid"))
+        self.supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
         self.supported_operators = (
             self.npu_pre_ops | self.mac_main_ops | self.elem_wise_main_ops | self.npu_post_ops | self.memory_only_ops
         )
@@ -103,6 +126,7 @@
         self.supported_operator_restrictions.update(
             {op: self.check_quantization_restrictions for op in self.binary_elem_wise_min_max_ops}
         )
+        self.supported_operator_restrictions.update({op: self.check_activation_ops for op in self.activation_ops})
 
     def is_operator_supported(self, op):
         if op.type not in self.supported_operators:
@@ -127,7 +151,10 @@
         for t in tensors:
             if not (t.dtype.type & BaseType.Int):
                 return False
-            if t.element_size() > 2 and op.type not in ("Requantize") | self.binary_elem_wise_add_mul_sub:
+            if (
+                t.element_size() > 2
+                and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub
+            ):
                 return False
             # check size
             if any(dim > 65536 for dim in t.shape):
@@ -212,7 +239,9 @@
         # check data type
         ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
         if ifm_tensor.dtype != ofm_tensor.dtype:
-            return False
+            if op.type != "ReduceSum":
+                return False
+            # TODO: else check ReduceSum restrictions.
 
         # check batch size
         if ifm_tensor.shape[0] != 1:
@@ -309,9 +338,33 @@
 
     def check_quantization_restrictions(self, op):
         # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
-        if (len(op.inputs) == 2
-            and not op.inputs[0].quantization == op.inputs[1].quantization == op.outputs[0].quantization):
-            print("Warning: Input/output tensors with different quantization is unsupported for the", op.type,
-                  "operator")
+        if (
+            len(op.inputs) == 2
+            and not op.inputs[0].quantization == op.inputs[1].quantization == op.outputs[0].quantization
+        ):
+            print(
+                "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
+            )
             return False
-        return True
\ No newline at end of file
+        return True
+
+    def check_activation_ops(self, op):
+        if op.type == "Softmax":
+            if not self.softmax_support:
+                return False
+
+            ifm_tensor = op.inputs[0]
+            ofm_tensor = op.outputs[0]
+
+            # check data type
+            if ifm_tensor.dtype != ofm_tensor.dtype:
+                return False
+
+            if ifm_tensor.dtype != DataType.int16:
+                return False  # TODO: Implement support for 8-bit Softmax
+
+            # check batch size
+            if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1:
+                return False
+
+        return True