[MLBEDSW-2335] SoftMax int16

Added graph rewrite of Softmax for int16.
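
For reference, a rough floating-point sketch of the decomposition the rewrite
targets is included below. It is illustrative only (plain NumPy, not Vela
code); the actual rewrite builds an equivalent fixed-point NPU graph out of
the LUT, ReduceSum and CLZ/SHR/SHL support enabled in this patch.

    import numpy as np

    def softmax_reference(ifm):
        """Float reference of the steps the int16 rewrite maps to NPU ops."""
        x = ifm.astype(np.float32)
        x = x - x.max(axis=-1, keepdims=True)   # max reduce + subtract
        e = np.exp(x)                            # exp, realised as a LUT
        s = e.sum(axis=-1, keepdims=True)        # ReduceSum
        return e / s                             # reciprocal (CLZ/SHR/SHL) + multiply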

Change-Id: Id7885af6056a23e8b8362fb61ae94283251eb398
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 38b40ba..e0f114e 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -38,6 +38,7 @@
 from .ethos_u55_regs.ethos_u55_regs import cmd1
 from .ethos_u55_regs.ethos_u55_regs import elementwise_mode
 from .ethos_u55_regs.ethos_u55_regs import ifm_precision
+from .ethos_u55_regs.ethos_u55_regs import pooling_mode
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .ethos_u55_regs.ethos_u55_regs import rounding
 from .high_level_command_stream import CommandType
@@ -52,6 +53,7 @@
 from .tensor import MemType
 from .tensor import TensorBlockTraversal
 from .tensor import TensorFormat
+from .tensor import TensorPurpose
 
 
 class RegisterMachine:
@@ -81,6 +83,7 @@
     WeightTensor = 0  # base address index for the Weight tensor
     ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
     ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
+    Mem2Mem = (1 << 8) | (3 << 0)  # base address slot for memory-to-memory transfer
 
 
 # TODO: Replace with definitions from ethos_u55_regs
@@ -220,7 +223,9 @@
                         elif memory_accesses[prev_cmd].conflicts(curr_accesses):
                             is_dependency = True
                     else:
-                        if memory_accesses[prev_cmd].conflicts(curr_accesses):
+                        if memory_accesses[prev_cmd].conflicts(curr_accesses) or (
+                            prev_cmd.cmdtype == CommandType.DMA and prev_cmd.in_tensor.purpose == TensorPurpose.LUT
+                        ):
                             is_dependency = True
 
                     if is_dependency:
@@ -295,7 +300,12 @@
     # Note: NOT equivalent to the normal ifm block depth calculation since
     # it takes into account 'depthless' block operations by returning full
     # depth
-    if cmd.ps.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling, NpuBlockType.ElementWise):
+    if cmd.ps.npu_block_type in (
+        NpuBlockType.ConvolutionDepthWise,
+        NpuBlockType.Pooling,
+        NpuBlockType.ElementWise,
+        NpuBlockType.ReduceSum,
+    ):
         return cmd.ofm_box.get_size_shape()[-1]
 
     return arch.calc_ifm_block_depth(cmd.ifm_box.get_size_shape()[-1], cmd.ifm_tensor.dtype.bits)
@@ -306,6 +316,7 @@
         NpuBlockType.ConvolutionDepthWise,
         NpuBlockType.Pooling,
         NpuBlockType.ConvolutionMxN,
+        NpuBlockType.ReduceSum,
     ):
         return (0, 0)
 
@@ -353,6 +364,9 @@
         "Maximum": elementwise_mode.MAX.value,
         "LeakyRelu": elementwise_mode.LRELU.value,
         "Abs": elementwise_mode.ABS.value,
+        "CLZ": elementwise_mode.CLZ.value,
+        "SHR": elementwise_mode.SHR.value,
+        "SHL": elementwise_mode.SHL.value,
     }
 
     cmd_stream = []
@@ -407,7 +421,10 @@
 
             emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, base_ptr_idx_map[cmd.in_tensor.mem_type])
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_SRC, src_addr)
-            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_type])
+            if cmd.out_tensor.purpose == TensorPurpose.LUT:
+                emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, BasePointerIndex.Mem2Mem)
+            else:
+                emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_type])
 
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_DST, dst_addr)
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, sz)
@@ -451,7 +468,9 @@
             shared_buffer = ps.shared_buffer
 
             if npu_block_type == NpuBlockType.ElementWise:
-                ifm2_broadcast = 0
+                ifm2_broadcast = (
+                    IFM2Broadcast.ReverseOperandOrder if primary_op.attrs.get("reverse_op_order", False) else 0
+                )
 
                 if cmd.ifm_tensor.shape == []:
                     # The scalar has to be the ifm2 tensor so switch the ifms
@@ -468,22 +487,26 @@
                     output_scale = cmd.ofm_tensor.quantization.scale_f32
                     use_global_scale = True
 
-                    if primary_op.type == "MulAct":
-                        if (faf == "Sigmoid") or (faf == "Tanh"):
-                            output_scale = 1 / 0x3000
+                    if output_scale is not None and faf in ("Sigmoid", "Tanh"):
+                        output_scale = 1 / 0x3000
 
-                        ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
+                    if primary_op.type == "MulAct":
+                        if None in (input_scale, input2_scale, output_scale):
+                            ofm_scale = 1
+                            shift = 0
+                        else:
+                            ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
                     else:  # AddAct/SubAct
-                        if (faf == "Sigmoid") or (faf == "Tanh"):
-                            output_scale = 1 / 0x3000
-
                         # Force output scale same as the input scale for
                         # resizebilinear 1x1 that is converted to add
                         if "resizebilinear" in primary_op.attrs:
                             output_scale = input2_scale
 
-                        if input_scale == input2_scale:
+                        if None in (input_scale, input2_scale, output_scale):
+                            opa_scale = opb_scale = ofm_scale = 1
+                            opa_shift = shift = 0
+                        elif input_scale == input2_scale:
                             opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
                                 input_scale, input2_scale, output_scale
                             )
@@ -512,7 +535,7 @@
                         emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale)
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
 
-                if primary_op.type in set(("LeakyRelu", "Abs",)):
+                elif primary_op.type in set(("LeakyRelu", "Abs",)):
                     output_scale = cmd.ofm_tensor.quantization.scale_f32
                     use_global_scale = True
 
@@ -521,6 +544,8 @@
 
                     ofm_scale, shift = scaling.quantise_scale(output_scale)
                     emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
+                else:
+                    emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0)
 
                 # For elementwise set the required SHRAM to be equal to the total size of SHRAM
                 shram_required = arch.shram_total_banks
@@ -581,7 +606,12 @@
                 emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.NONE)
 
             if npu_block_type in set(
-                (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
+                (
+                    NpuBlockType.ConvolutionMxN,
+                    NpuBlockType.ConvolutionDepthWise,
+                    NpuBlockType.Pooling,
+                    NpuBlockType.ReduceSum,
+                )
             ):
                 # Set up padding
                 explicit_padding = list(primary_op.attrs["explicit_padding"])  # (top, left, bottom, right)
@@ -611,14 +641,17 @@
                 # set kernel y stride extension bits
                 stride |= (primary_op.attrs["strides"][1] - 1 >> 1) << 9
 
-                if npu_block_type == NpuBlockType.Pooling:
+                if npu_block_type in set((NpuBlockType.Pooling, NpuBlockType.ReduceSum)):
                     k_height, k_width = primary_op.attrs["ksize"][1:3]
                     emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, k_height - 1)
                     emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, k_width - 1)
 
                     valid_padding = sum(explicit_padding) == 0
 
-                    if primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear")) and valid_padding:
+                    if (
+                        primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "ReduceSum"))
+                        and valid_padding
+                    ):
                         # For valid padding vela has to output scaling values
                         if faf == "Sigmoid" or faf == "Tanh":
                             rescale = 0x3000 * cmd.ifm_tensor.quantization.scale_f32
@@ -644,17 +677,24 @@
                             # k_height == k_width == 1 is always true in this case
                             # Normally the scale is maximised, to get maximum precision, which means that
                             # if rescale != 1, the scale needs to consider the number of bits needed for rescaling
-                            rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
-                            rescale_bits = 0
-                            if k_height == k_width == 1:
-                                if fmf == "ConcatSliceWrite":
-                                    rounding_mode = rounding.NATURAL
-                                if rescale > 1:
-                                    rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
-                                elif rescale < 1:
-                                    rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
-                            scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
-                            scale = int(round_away_zero(scale * rescale))
+                            if None not in (
+                                cmd.ofm_tensor.quantization.scale_f32,
+                                cmd.ifm_tensor.quantization.scale_f32,
+                            ):
+                                rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
+                                rescale_bits = 0
+                                if k_height == k_width == 1:
+                                    if fmf == "ConcatSliceWrite":
+                                        rounding_mode = rounding.NATURAL
+                                    if rescale > 1:
+                                        rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
+                                    elif rescale < 1:
+                                        rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
+                                scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
+                                scale = int(round_away_zero(scale * rescale))
+                            else:
+                                scale = 1
+                                shift = 0
 
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, scale, shift)
                         # Valid-padded average pool should use the global scale from
@@ -798,6 +838,12 @@
                 else:
                     faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                     faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
+            elif faf == "LUT":
+                lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", 0)
+                assert lut_index <= activation.LUT_END.value, "LUT index out of range."
+                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index)
+                faf_min = ofm_quant_qmin
+                faf_max = ofm_quant_qmax
             else:
                 raise Exception("Unsupported fused_activation_function = " + faf)
 
@@ -816,7 +862,7 @@
                 emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, 0)
             emit.cmd0_with_param(cmd0.NPU_SET_OFM_DEPTH_M1, out_shape[-1] - 1)
 
-            if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
+            if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)):
                 in_shape = cmd.ifm_box.get_size_shape()
                 emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, in_shape[-1] - 1)
             else:
@@ -942,6 +988,8 @@
                 prec = 0
             elif ofm_dtype.size_in_bits() == 16:
                 prec = 2
+            elif ofm_dtype.size_in_bits() == 32:
+                prec = 4
             else:
                 assert 0
 
@@ -967,7 +1015,11 @@
             ifm_dtype = cmd.ifm_tensor.dtype
 
             assert weight_bits == 8, "Unsupported weight bit depth"
-            assert ifm_dtype.size_in_bits() in {8, 16}
+            assert (
+                ifm_dtype.size_in_bits() in {8, 16}
+                or ifm_dtype.size_in_bits() == 32
+                and npu_block_type in (NpuBlockType.ElementWise, NpuBlockType.ReduceSum)
+            ), "Unsupported ifm bit depth"
 
             if ifm_dtype.size_in_bits() == 8:
                 if ifm_dtype.type & BaseType.Signed:
@@ -979,6 +1031,8 @@
                     prec = ifm_precision.S16
                 else:
                     prec = ifm_precision.U16
+            elif ifm_dtype == DataType.int32:
+                prec = ifm_precision.S32
 
             ifm_prec = prec.value
             ifm2_prec = ifm_prec
@@ -1036,8 +1090,10 @@
                 # Vector product is implemented using a 1x1 convolution
                 emit.cmd_do_operation(cmd0.NPU_OP_CONV)
             elif npu_block_type == NpuBlockType.Pooling:
-                param = "Max" not in primary_op.type
+                param = pooling_mode.MAX.value if "Max" in primary_op.type else pooling_mode.AVERAGE.value
                 emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=param)
+            elif npu_block_type == NpuBlockType.ReduceSum:
+                emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_mode.REDUCE_SUM.value)
             elif npu_block_type == NpuBlockType.ElementWise:
                 param = elementwise_mode_map[primary_op.type]
                 emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param)