[MLBEDSW-2657] Softmax uint8/int8

Added graph rewrite of Softmax for uint8/int8. The op is lowered to a
sequence of NPU passes: a depthwise MaxPool for the per-axis max, a Sub
with a fused exp LUT, a pre-scaled ReduceSum, and a fixed-point
Newton-Raphson reciprocal, following the TFLite reference implementation.
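
For reference, a minimal Python sketch (not part of the patch; helper
names are illustrative) of the reciprocal that PASSES 4-28 of the new
graph compute, mirroring gemmlowp's one_over_one_plus_x_for_x_in_0_1
from the TFLite softmax reference:

    def srdhm(a, b):
        # models the rounding doubling high multiply: round(a * b / 2^31)
        return (a * b + (1 << 30)) >> 31

    def reciprocal_scale(sum_of_exp):
        headroom_plus_one = 32 - sum_of_exp.bit_length()          # PASS 4 (CLZ)
        right_shift = (12 + 31 - 8) - headroom_plus_one           # PASS 5
        # normalise the sum so its MSB sits at bit 30: Q31 value in [0.5, 1.0)
        # (what PASSES 6-10 compute, up to rounding)
        half_denominator = sum_of_exp << (headroom_plus_one - 1)
        # Newton-Raphson for 1 / half_denominator, seeded with
        # x0 = 48/17 - 32/17 * half_denominator (constants in Q2.29)
        x = 1515870810 + srdhm(half_denominator, -1010580540)     # PASSES 11-12
        for _ in range(3):                                        # PASSES 13-27
            one_minus = (1 << 29) - srdhm(half_denominator, x)
            x += srdhm(x, one_minus) << 2
        return x << 1, right_shift                                # PASS 28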

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Iecdd5d2cd3156a601b3313debba4a3562e6be5d7
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 2e53a69..5453f2c 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -94,7 +94,7 @@
     IFM16 = 1
     IFM8_Elementwise = 2
     IFM16_Elementwise = 3
-    IFM32_Elementwise = 4
+    IFM32 = 4
     Acc16 = 5
     Acc32 = 6
     Acc40 = 7
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index bdc3722..f8bee6c 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -499,6 +499,7 @@
                         if None in (input_scale, input2_scale, output_scale):
                             opa_scale = opb_scale = ofm_scale = 1
                             opa_shift = shift = 0
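+                            # int32 elementwise ops have no scales; a "rescale" attr set by graph rewrites
+                            # (e.g. the softmax AddAct) supplies an explicit [ofm_scale, shift] instead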
+                            ofm_scale, shift = primary_op.attrs.get("rescale", [1, 0])
                         elif input_scale == input2_scale:
                             opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
                                 input_scale, input2_scale, output_scale
@@ -835,6 +836,8 @@
             elif faf == "LUT":
                 lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1)
                 assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range."
+                if cmd.ofm_tensor.dtype == DataType.int32:
+                    lut_index |= (3 << 12)  # Force I8 range
                 emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index)
                 faf_min = ofm_quant_qmin
                 faf_max = ofm_quant_qmax
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index c968619..053377c 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -85,8 +85,8 @@
                 assert (self.use_ifm_element == SHRAMElements.IFM16) or (
                     self.use_ifm_element == SHRAMElements.IFM16_Elementwise
                 )
-            elif is_elementwise and self.ifm_bits == 32:
-                self.use_ifm_element = SHRAMElements.IFM32_Elementwise
+            elif (is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum) and self.ifm_bits == 32:
+                self.use_ifm_element = SHRAMElements.IFM32
             else:
                 assert self.ifm_bits == 8, "Unexpected IFM bitdepth"
 
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 0a589eb..c67cc37 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -29,6 +29,77 @@
 class SoftMax:
     # Turn off black formatting for the LUT tables to keep them compact
     # fmt: off
+
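+    # Fixed-point exp tables for the 8-bit softmax rewrite, indexed by the 8-bit result of (ifm - max).
+    # The largest entry, exp(0), saturates to 0x7fffffff. See generate_exp_table below.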
+    EXP_LUT_U8 = [
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+        0x00000291, 0x000006fa, 0x000012f6, 0x0000338b, 0x00008c1b, 0x00017cd8, 0x00040b3d, 0x000afe11,
+        0x001de16c, 0x00513949, 0x00dcca03, 0x02582ac2, 0x065f6c52, 0x1152aaf6, 0x2f16ad4c, 0x7fffffff
+    ]
+
+    EXP_LUT_I8 = [
+        0x000011c9, 0x000012b8, 0x000013b4, 0x000014bd, 0x000015d4, 0x000016fa, 0x0000182f, 0x00001975,
+        0x00001acb, 0x00001c34, 0x00001daf, 0x00001f3f, 0x000020e3, 0x0000229e, 0x00002470, 0x0000265a,
+        0x0000285e, 0x00002a7d, 0x00002cb9, 0x00002f13, 0x0000318c, 0x00003427, 0x000036e5, 0x000039c8,
+        0x00003cd1, 0x00004004, 0x00004361, 0x000046ec, 0x00004aa6, 0x00004e93, 0x000052b4, 0x0000570d,
+        0x00005ba1, 0x00006072, 0x00006583, 0x00006ada, 0x00007077, 0x00007661, 0x00007c9a, 0x00008327,
+        0x00008a0c, 0x0000914d, 0x000098f1, 0x0000a0fb, 0x0000a971, 0x0000b259, 0x0000bbb9, 0x0000c597,
+        0x0000cffa, 0x0000dae9, 0x0000e66b, 0x0000f288, 0x0000ff48, 0x00010cb3, 0x00011ad3, 0x000129b1,
+        0x00013957, 0x000149d0, 0x00015b26, 0x00016d65, 0x0001809b, 0x000194d2, 0x0001aa1a, 0x0001c080,
+        0x0001d814, 0x0001f0e4, 0x00020b03, 0x00022681, 0x00024371, 0x000261e7, 0x000281f7, 0x0002a3b5,
+        0x0002c73b, 0x0002ec9e, 0x000313f8, 0x00033d64, 0x000368fd, 0x000396e0, 0x0003c72e, 0x0003fa05,
+        0x00042f89, 0x000467dd, 0x0004a326, 0x0004e18e, 0x0005233d, 0x00056860, 0x0005b126, 0x0005fdbf,
+        0x00064e5f, 0x0006a33b, 0x0006fc8e, 0x00075a93, 0x0007bd89, 0x000825b3, 0x00089356, 0x000906bd,
+        0x00098034, 0x000a000f, 0x000a86a2, 0x000b1447, 0x000ba95f, 0x000c464d, 0x000ceb7c, 0x000d9959,
+        0x000e505a, 0x000f10f9, 0x000fdbb8, 0x0010b120, 0x001191c0, 0x00127e2f, 0x0013770b, 0x00147cfc,
+        0x001590b2, 0x0016b2e6, 0x0017e45c, 0x001925e1, 0x001a784c, 0x001bdc81, 0x001d536f, 0x001ede14,
+        0x00207d76, 0x002232af, 0x0023fee3, 0x0025e348, 0x0027e125, 0x0029f9ce, 0x002c2ead, 0x002e813e,
+        0x0030f30f, 0x003385c7, 0x00363b1e, 0x003914e9, 0x003c150f, 0x003f3d97, 0x004290a0, 0x00461065,
+        0x0049bf40, 0x004d9fac, 0x0051b444, 0x0055ffc2, 0x005a850e, 0x005f472f, 0x00644959, 0x00698eea,
+        0x006f1b6b, 0x0074f298, 0x007b185e, 0x008190dd, 0x00886073, 0x008f8bad, 0x00971761, 0x009f08a0,
+        0x00a764c0, 0x00b03163, 0x00b9746c, 0x00c3341a, 0x00cd76f8, 0x00d843eb, 0x00e3a23a, 0x00ef9981,
+        0x00fc31d0, 0x0109739d, 0x011767cf, 0x012617cd, 0x01358d6e, 0x0145d319, 0x0156f3be, 0x0168fadc,
+        0x017bf49d, 0x018fedb3, 0x01a4f391, 0x01bb1457, 0x01d25ede, 0x01eae2e1, 0x0204b0c5, 0x021fd9e9,
+        0x023c708e, 0x025a87f5, 0x027a343a, 0x029b8ac1, 0x02bea1ea, 0x02e39148, 0x030a71be, 0x03335d49,
+        0x035e6f88, 0x038bc564, 0x03bb7d53, 0x03edb776, 0x0422956d, 0x045a3add, 0x0494cd23, 0x04d27398,
+        0x051357c1, 0x0557a511, 0x059f8990, 0x05eb3585, 0x063adbc4, 0x068eb1f7, 0x06e6f042, 0x0743d212,
+        0x07a595d0, 0x080c7d1f, 0x0878cd5d, 0x08eacf11, 0x0962cefe, 0x09e11dc0, 0x0a661028, 0x0af1ffdf,
+        0x0b854a8e, 0x0c205363, 0x0cc38284, 0x0d6f4577, 0x0e241032, 0x0ee25ba2, 0x0faaa7e6, 0x107d7b92,
+        0x115b64b1, 0x1244f774, 0x133ad1b8, 0x143d9876, 0x154df988, 0x166cac69, 0x179a70c9, 0x18d81250,
+        0x1a266643, 0x1b864d38, 0x1cf8b430, 0x1e7e9307, 0x2018f0a9, 0x21c8e098, 0x238f850c, 0x256e1033,
+        0x2765c273, 0x2977ef40, 0x2ba5faa9, 0x2df15b73, 0x305b9d6b, 0x32e65e8a, 0x3593552c, 0x38644d67,
+        0x3b5b2b66, 0x3e79ee87, 0x41c2adcb, 0x45379f4e, 0x48db158a, 0x4caf81e6, 0x50b7797f, 0x54f5af16,
+        0x596cfe2f, 0x5e2066d0, 0x631310c8, 0x684852d8, 0x6dc3a909, 0x7388c421, 0x799b84b7, 0x7fffffff,
+    ]
+
     EXP_LUT = [
         0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
         0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
@@ -167,16 +238,263 @@
     def __init__(self, op):
         self.op = op
 
+    def generate_exp_table(self, beta, input_scale):
+        # TODO: Generate the exp table using the same math as the reference
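+        # Until then, choose between two precomputed 256-entry tables based on the input scale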
+        return self.EXP_LUT_U8 if input_scale == 1.0 else self.EXP_LUT_I8
+
     def get_graph(self):
         ifm = self.op.inputs[0]
         ofm = self.op.outputs[0]
 
-        if ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
+        if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
+            return self.get_graph_8bit(ifm, ofm)
+        elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
             return self.get_graph_int16(ifm, ofm)
         else:
             self.op.run_on_npu = False
             return self.op
 
+    def get_graph_8bit(self, ifm, ofm):
+        exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
+        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
+        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
+        no_scale_quant = ifm.quantization.clone()
+        no_scale_quant.scale_f32 = None
+        no_scale_quant.zero_point = 0
+        one_scale_quant = ifm.quantization.clone()
+        one_scale_quant.scale_f32 = 1.0
+        one_scale_quant.zero_point = 0
+        ifm.quantization.zero_point = 0
+
+        # PASS 0 - Depthwise Maxpool
+        maxpool_op = self.op.clone("_maxpool0")
+        maxpool_op.type = "MaxPool"
+        maxpool_h = ifm.shape[1] * ifm.shape[2]
+        maxpool_w = ifm.shape[3]
+        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
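+        # View [1, H, W, C] as [1, H * W, C, 1] so a 1 x C pooling window reduces over the softmax axis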
+        maxpool_op.attrs["padding"] = b"VALID"
+        maxpool_op.attrs["stride_w"] = 1
+        maxpool_op.attrs["stride_h"] = 1
+        maxpool_op.attrs["filter_width"] = maxpool_w
+        maxpool_op.attrs["filter_height"] = 1
+        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
+        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
+        maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
+        ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
+        ifm_max.quantization = no_scale_quant
+        maxpool_op.set_output_tensor(ifm_max)
+
+        # PASS 1 - Sub+LUT(exp)
+        sub_op = Operation("SubAct", self.op.name + "_sub1")
+        sub_op.add_input_tensor(ifm)
+        sub_op.add_input_tensor(ifm_max)
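+        # The LUT runs as the Sub activation, so this pass produces exp(ifm - max) directly as int32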
+        sub_op.set_activation_lut(
+            create_const_tensor(
+                sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
+            )
+        )
+        ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0")
+        ifm_exp.quantization = one_scale_quant.clone()
+        ifm_exp.quantization.zero_point = 127
+        ifm_exp.quantization.quant_min = -128
+        ifm_exp.quantization.quant_max = 127
+        sub_op.set_output_tensor(ifm_exp)
+
+        # PASS 2 - SHR
+        shr2_op = Operation("SHR", self.op.name + "_shr2")
+        shr2_op.add_input_tensor(ifm_exp)
+        shr2_op.add_input_tensor(
+            create_const_tensor(
+                shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
+            ),
+        )
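+        # Drop 12 fractional bits from each exponential so the int32 sum in PASS 3 cannot overflow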
+        rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
+        rescaled_exp.quantization = no_scale_quant
+        shr2_op.set_output_tensor(rescaled_exp)
+
+        # PASS 3 - Reduce sum
+        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum3")
+        reduce_sum_op.attrs["padding"] = b"VALID"
+        reduce_sum_op.attrs["stride_w"] = 1
+        reduce_sum_op.attrs["stride_h"] = 1
+        reduce_sum_op.attrs["filter_width"] = 1
+        reduce_sum_op.attrs["filter_height"] = 1
+        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
+        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
+        reduce_sum_op.add_input_tensor(rescaled_exp)
+
+        reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1]
+        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
+        sum_of_exp.quantization = no_scale_quant
+        reduce_sum_op.set_output_tensor(sum_of_exp)
+
+        # PASS 4 - CLZ
+        clz_op = Operation("CLZ", self.op.name + "_clz4")
+        clz_op.add_input_tensor(sum_of_exp)
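+        # Leading-zero count of the (positive) sum: its headroom plus one for the sign bit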
+        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
+        headroom_plus_one.quantization = no_scale_quant
+        clz_op.set_output_tensor(headroom_plus_one)
+
+        # PASS 5 - Sub
+        sub5_op = Operation("SubAct", self.op.name + "_sub5")
+        sub5_op.add_input_tensor(
+            create_const_tensor(
+                "headroom_offset_const", [1, 1, 1, 1], DataType.int32, [12 + 31 - 8], np.int32, quantization=no_scale_quant
+            ),
+        )
+        sub5_op.add_input_tensor(headroom_plus_one)
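+        # right_shift = (12 + 31 - 8) - headroom_plus_one is consumed by PASS 30: it restores the 12 bits
+        # dropped in PASS 2 and narrows the final Q31 product down to the 8-bit output range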
+        right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
+        right_shift.quantization = no_scale_quant
+        sub5_op.set_output_tensor(right_shift)
+
+        # PASS 6 - Sub
+        one = create_const_tensor(
+            "one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
+        )
+        sub6_op = Operation("SubAct", self.op.name + "_sub6")
+        sub6_op.add_input_tensor(headroom_plus_one)
+        sub6_op.add_input_tensor(one)
+        headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
+        headroom.quantization = no_scale_quant
+        sub6_op.set_output_tensor(headroom)
+
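+        # PASSES 7-9 build (sum_of_exp << headroom_plus_one) - (1 << 31) without overflowing the signed
+        # intermediate: shift left by headroom, subtract 1 << 30, then double. In Q31 this is the
+        # normalised sum minus 1.0, the x in [0, 1) fed to the reciprocal approximation below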
+        # PASS 7 - SHL
+        shl7_op = Operation("SHL", self.op.name + "_shl7")
+        shl7_op.add_input_tensor(sum_of_exp)
+        shl7_op.add_input_tensor(headroom)
+        shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
+        shifted_sum.quantization = no_scale_quant
+        shl7_op.set_output_tensor(shifted_sum)
+
+        # PASS 8 - Sub
+        sub8_op = Operation("SubAct", self.op.name + "_sub8")
+        sub8_op.add_input_tensor(shifted_sum)
+        sub8_op.add_input_tensor(
+            create_const_tensor(
+                "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
+            ),
+        )
+        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
+        shifted_sum_minus_one.quantization = no_scale_quant
+        sub8_op.set_output_tensor(shifted_sum_minus_one)
+
+        # PASS 9 - SHL
+        shl9_op = Operation("SHL", self.op.name + "_shl9")
+        shl9_op.add_input_tensor(shifted_sum_minus_one)
+        shl9_op.add_input_tensor(one)
+        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
+        shifted_sum_minus_one.quantization = no_scale_quant
+        shl9_op.set_output_tensor(shifted_sum_minus_one)
+
+        # PASS 10 - Add
+        add10_op = Operation("AddAct", self.op.name + "_add10")
+        add10_op.add_input_tensor(
+            create_const_tensor(
+                "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
+            ),
+        )
+        add10_op.add_input_tensor(shifted_sum_minus_one)
+        add10_op.attrs["rescale"] = [1, 1]
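+        # rescale = [ofm_scale, shift]; halving here forms half_denominator = (1 + x) / 2 in [0.5, 1.0)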
+        half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
+        half_denominator.quantization = one_scale_quant
+        add10_op.set_output_tensor(half_denominator)
+
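+        # PASSES 11-27 mirror gemmlowp's one_over_one_plus_x_for_x_in_0_1: Newton-Raphson for
+        # 1 / half_denominator, seeded with x0 = 48/17 - 32/17 * half_denominator (Q2.29 constants below)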
+        # PASS 11 - Multiply
+        mul11_op = Operation("MulAct", self.op.name + "_mul11")
+        mul11_op.add_input_tensor(half_denominator)
+        mul11_op.add_input_tensor(
+            create_const_tensor(
+                "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], np.int32, quantization=one_scale_quant
+            ),
+        )
+        rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0")
+        rescaled.quantization = one_scale_quant.clone()
+        rescaled.quantization.scale_f32 = 2.0
+        mul11_op.set_output_tensor(rescaled)
+
+        # PASS 12 - Add
+        add12_op = Operation("AddAct", self.op.name + "_add12")
+        add12_op.add_input_tensor(rescaled)
+        add12_op.add_input_tensor(
+            create_const_tensor(
+                "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
+            ),
+        )
+        rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
+        rescale_w_offset.quantization = one_scale_quant
+        add12_op.set_output_tensor(rescale_w_offset)
+
+        nr_x = rescale_w_offset
+        F2_one = create_const_tensor(
+            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
+        )
+        two = create_const_tensor(
+            "two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant
+        )
+        for i in range(3):
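+            # One Newton-Raphson step: x <- x + x * (1 - half_denominator * x), evaluated in Q2.29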
+            # PASS 13, 18, 23 - MUL
+            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
+            mul_op.add_input_tensor(nr_x)
+            mul_op.add_input_tensor(half_denominator)
+            half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
+            half_denominator_times_x.quantization = one_scale_quant.clone()
+            half_denominator_times_x.quantization.scale_f32 = 2.0
+            mul_op.set_output_tensor(half_denominator_times_x)
+            # PASS 14, 19, 24 - SUB
+            sub_op = Operation("SubAct", self.op.name + "_sub%d" % (14 + i * 5))
+            sub_op.add_input_tensor(F2_one)
+            sub_op.add_input_tensor(half_denominator_times_x)
+            one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
+            one_minus_half_denominator_times_x.quantization = one_scale_quant
+            sub_op.set_output_tensor(one_minus_half_denominator_times_x)
+            # PASS 15, 20, 25 - MUL
+            mul_op = Operation("MulAct", self.op.name + "_mul%d" %+ (15 + i * 5))
+            mul_op.add_input_tensor(nr_x)
+            mul_op.add_input_tensor(one_minus_half_denominator_times_x)
+            to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
+            to_rescale.quantization = one_scale_quant.clone()
+            to_rescale.quantization.scale_f32 = 2.0
+            mul_op.set_output_tensor(to_rescale)
+            # PASS 16, 21, 26 - SHL
+            shl_op = Operation("SHL", self.op.name + "_shl%d" % (16 + i * 5))
+            shl_op.add_input_tensor(to_rescale)
+            shl_op.add_input_tensor(two)
+            to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
+            to_add.quantization = no_scale_quant
+            shl_op.set_output_tensor(to_add)
+            # PASS 17, 22, 27 - ADD
+            add_op = Operation("AddAct", self.op.name + "_add%d" % (17 + i * 5))
+            add_op.add_input_tensor(nr_x)
+            add_op.add_input_tensor(to_add)
+            nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
+            nr_x.quantization = one_scale_quant
+            add_op.set_output_tensor(nr_x)
+
+        # PASS 28 - SHL
+        shl28_op = Operation("SHL", self.op.name + "_shl28")
+        shl28_op.add_input_tensor(nr_x)
+        shl28_op.add_input_tensor(one)
+        scale_factor = Tensor(reduce_sum_shape, DataType.int32, shl28_op.name + "_0")
+        scale_factor.quantization = one_scale_quant
+        shl28_op.set_output_tensor(scale_factor)
+
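+        # PASSES 29-30: multiply each exponential by the reciprocal, then shift right into the 8-bit range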
+        # PASS 29 - Multiply
+        mul_op = Operation("MulAct", self.op.name + "_mul29")
+        mul_op.add_input_tensor(ifm_exp)
+        mul_op.add_input_tensor(scale_factor)
+        scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
+        scaled_exp.quantization = one_scale_quant.clone()
+        scaled_exp.quantization.scale_f32 = 2.0
+        mul_op.set_output_tensor(scaled_exp)
+
+        # PASS 30 - SHR
+        shr30_op = Operation("SHR", self.op.name + "_shr30")
+        shr30_op.add_input_tensor(scaled_exp)
+        shr30_op.add_input_tensor(right_shift)
+        shr30_op.set_output_tensor(ofm)
+
+        return shr30_op
+
     def get_graph_int16(self, ifm, ofm):
         ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
         ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
@@ -197,7 +515,7 @@
         maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
         maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
         maxpool_op.inputs = [create_reshape_tensor(ifm, maxpool_ifm_shape)]
-        maxpool_ofm = Tensor([1, maxpool_h, 1, 1], DataType.int16, maxpool_op.name + "_0")
+        maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
         maxpool_ofm.quantization = no_scale_quant
         maxpool_op.set_output_tensor(maxpool_ofm)
 
@@ -219,7 +537,7 @@
         mul2_op.add_input_tensor(sub1_ofm)
         mul2_op.add_input_tensor(
             create_const_tensor(
-                mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.uint32, quantization=mul2_quant
+                mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=mul2_quant
             ),
         )
         mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
@@ -232,12 +550,12 @@
         add_op.add_input_tensor(mul2_ofm)
         add_op.add_input_tensor(
             create_const_tensor(
-                add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.uint32, quantization=no_scale_quant
+                add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
             ),
         )
         add_op.set_activation_lut(
             create_const_tensor(
-                add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.uint32, TensorPurpose.LUT
+                add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
             )
         )
         exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
@@ -271,7 +589,7 @@
         sub6_op = Operation("SubAct", self.op.name + "_sub6")
         sub6_op.add_input_tensor(
             create_const_tensor(
-                sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.uint32, quantization=no_scale_quant
+                sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
             ),
         )
         sub6_op.add_input_tensor(headroom_plus_one)
@@ -283,11 +601,11 @@
         shl7_op = Operation("SHL", self.op.name + "_shl7")
         shl7_op.add_input_tensor(
             create_const_tensor(
-                shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.uint32, quantization=no_scale_quant
+                shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
             ),
         )
         shl7_op.add_input_tensor(reciprocal_right_shift)
-        constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "0")
+        constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
         constant_one.quantization = no_scale_quant
         shl7_op.set_output_tensor(constant_one)
 
@@ -312,7 +630,7 @@
         shr10_op.add_input_tensor(shifted_sum_minus_one)
         shr10_op.add_input_tensor(
             create_const_tensor(
-                shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.uint32, quantization=no_scale_quant
+                shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
             ),
         )
         shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
@@ -324,7 +642,7 @@
         sub11_op.add_input_tensor(shifted_sum_minus_one_16)
         sub11_op.add_input_tensor(
             create_const_tensor(
-                sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.uint32, quantization=no_scale_quant
+                sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
             ),
         )
         sub11_op.set_activation_lut(
@@ -333,7 +651,7 @@
                 [1, 1, 1, 512],
                 DataType.int32,
                 self.ONE_OVER_ONE_PLUS_X_LUT,
-                np.uint32,
+                np.int32,
                 TensorPurpose.LUT,
             )
         )
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 43ba36f..fdf0c6b 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -52,6 +52,7 @@
         )
         self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
         self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
+        self.binary_elem_wise_shift_ops = set(("SHL", "SHR",))
         self.binary_elem_wise_add_mul_sub = set(
             (
                 "AddAct",
@@ -63,11 +64,9 @@
                 "Mul",
                 "Add",
                 "Sub",
-                "SHL",
-                "SHR",
             )
         )
-        self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
+        self.binary_elem_wise_main_ops = (
+            self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
+        )
         self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
         self.activation_ops = set(
             (
@@ -153,7 +152,7 @@
                 return False
             if (
                 t.element_size() > 2
-                and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub
+                and op.type
+                not in set(("Requantize", "ReduceSum", "CLZ",))
+                | self.binary_elem_wise_add_mul_sub
+                | self.binary_elem_wise_shift_ops
             ):
                 return False
             # check size
@@ -311,6 +310,11 @@
                 ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
             ):
                 return False
+        elif op.type in self.binary_elem_wise_shift_ops | set(("CLZ",)):
+            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
+                return False
+            if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32:
+                return False
 
         # check batch size
         if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
@@ -365,8 +369,8 @@
             if ifm_tensor.dtype != ofm_tensor.dtype:
                 return False
 
-            if ifm_tensor.dtype != DataType.int16:
-                return False  # TODO: Implement support for 8-bit Softmax
+            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
+                return False
 
             # check batch size
             if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1: