[MLBEDSW-2335] SoftMax int16

Added graph rewrite of Softmax for int16.
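
The rewrite decomposes the operator into a sequence of NPU primitives:
a depthwise MaxPool, elementwise Sub/Mul, an Add fused with an exp LUT,
ReduceSum, CLZ, shifts and a 1/(1+x) LUT. As a rough float-domain sketch
of what the pass sequence computes (illustrative only; the helper name
and NumPy usage below are not part of this change):

    import numpy as np

    def softmax_reference(x, beta=1.0):
        # Passes 0-2: subtract the per-position max over the depth axis
        # and scale by beta, so every value fed to exp is <= 0
        x = beta * (x - np.max(x, axis=-1, keepdims=True))
        # Pass 3: exponentiate (a 512-entry LUT on the NPU)
        e = np.exp(x)
        # Pass 4 onwards: normalise by the sum of exponentials
        # (ReduceSum, CLZ, shifts and a reciprocal LUT in fixed point)
        return e / np.sum(e, axis=-1, keepdims=True)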

Change-Id: Id7885af6056a23e8b8362fb61ae94283251eb398
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 43b3210..822bc11 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -94,14 +94,15 @@
     IFM16 = 1
     IFM8_Elementwise = 2
     IFM16_Elementwise = 3
-    Acc16 = 4
-    Acc32 = 5
-    Acc40 = 6
+    IFM32_Elementwise = 4
+    Acc16 = 5
+    Acc32 = 6
+    Acc40 = 7
     Last = Acc40
-    BitSizes = np.array([8, 16, 8, 16, 16, 32, 40], np.int32)
+    BitSizes = np.array([8, 16, 8, 16, 32, 16, 32, 40], np.int32)
     ByteSizes = BitSizes // 8
-    PostAlign = np.array([8, 8, 8, 8, 1, 1, 1], np.int32)
-    PreAlign = np.array([1, 1, 1, 1, 8, 8, 8], np.int32)
+    PostAlign = np.array([8, 8, 8, 8, 8, 1, 1, 1], np.int32)
+    PreAlign = np.array([1, 1, 1, 1, 1, 8, 8, 8], np.int32)
 
 
 class SHRAMBlockConfig:
@@ -150,22 +151,22 @@
     )
     accelerator_configs = {
         Accelerator.Yoda_512: ArchitectureConfig(
-            256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+            256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 16, 8, 16, 20], 8
         ),
         Accelerator.Yoda_256: ArchitectureConfig(
-            256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+            256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 16, 8, 16, 20], 8
         ),
         Accelerator.Ethos_U55_256: ArchitectureConfig(
-            256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+            256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 16, 8, 16, 20], 8
         ),
         Accelerator.Ethos_U55_128: ArchitectureConfig(
-            128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 4, 8, 12], 4
+            128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 8, 4, 8, 12], 4
         ),
         Accelerator.Ethos_U55_64: ArchitectureConfig(
-            64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 8], 2
+            64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4, 8], 2
         ),
         Accelerator.Ethos_U55_32: ArchitectureConfig(
-            32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4], 1
+            32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4, 4], 1
         ),
     }
 
@@ -182,6 +183,7 @@
         block_config_limit,
         global_memory_clock_scale,
         max_blockdep,
+        softmax_support,
     ):
         accelerator_config = accelerator_config.lower()
         self.vela_config = vela_config
@@ -262,11 +264,13 @@
             TensorPurpose.Unknown: MemArea.Unknown,
             TensorPurpose.Weights: self.permanent_storage_mem_area,
             TensorPurpose.FeatureMap: self.feature_map_storage_mem_area,
+            TensorPurpose.LUT: self.permanent_storage_mem_area,
         }
 
         self.tensor_storage_mem_type = {
             TensorPurpose.Weights: MemType.Permanent_NPU,
             TensorPurpose.FeatureMap: MemType.Scratch,
+            TensorPurpose.LUT: MemType.Scratch,
         }
 
         self.min_block_sizes = {
@@ -276,6 +280,7 @@
             NpuBlockType.Pooling: (dpu_min_height, dpu_min_width),
             NpuBlockType.ConvolutionDepthWise: (dpu_min_height, dpu_min_width),
             NpuBlockType.ElementWise: (1, 1),
+            NpuBlockType.ReduceSum: (dpu_min_height, dpu_min_width),
         }
 
         self.sub_kernel_limits = {
@@ -285,6 +290,7 @@
             NpuBlockType.Pooling: (8, 8),
             NpuBlockType.ConvolutionDepthWise: (8, 8),
             NpuBlockType.ElementWise: (1, 1),
+            NpuBlockType.ReduceSum: (8, 8),
         }
 
         # weights for scheduler search
@@ -317,7 +323,7 @@
         self.generate_block_config_map(Block(ifm_block_max.width, ifm_block_max.height, 128))
 
         # Setup supported operators and restriction checkers class
-        self.supported_operators = SupportedOperators()
+        self.supported_operators = SupportedOperators(softmax_support)
 
     # Calculate block configuration for ALL known IFM operations and
     # accumulator sizes. Consumers will need to select their preferred
@@ -358,10 +364,10 @@
                     self.block_config_map[key] = self.generate_block_config(w, h, c)
 
     def calc_ifm_block_depth(self, ifm_depth, ifm_bits):
-        assert ifm_bits == 8 or ifm_bits == 16
+        assert ifm_bits in (8, 16, 32)
         assert ifm_depth > 0
         ifm_depth = round_up(ifm_depth, self.ifm_ublock.depth)
-        max_block_depth = 32 if ifm_bits == 8 else 16
+        max_block_depth = 8 * 32 // ifm_bits
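+        # i.e. 32 for 8-bit, 16 for 16-bit and 8 for 32-bit IFMs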
         return min(max_block_depth, ifm_depth)
 
     # Calculate the size of the IFM block given a depth, target OFM block and a kernel
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 355b16f..9c6e1f5 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -27,6 +27,7 @@
 from .numeric_util import full_shape
 from .operation import NpuBlockType
 from .operation import Operation
+from .softmax import SoftMax
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
@@ -357,7 +358,7 @@
         if "Conv" in op.type:
             kernel_size = op.inputs[1].shape[:2]
             input_shape = op.inputs[0].shape
-        elif "Pool" in op.type or "ResizeBilinear" == op.type:
+        elif "Pool" in op.type or op.type in ("ResizeBilinear", "ReduceSum"):
             kernel_size = op.attrs["ksize"][1:3]
             input_shape = op.inputs[0].shape
         elif op.type == "ExtractImagePatches":
@@ -401,9 +402,10 @@
 )
 depthwise_op = set(("DepthwiseConv2dNative", "DepthwiseConv2dBiasAct",))
 pool_op = set(
-    ("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear",)
+    ("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear")
 )
-elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum", "LeakyRelu", "Abs"))
+reduce_sum_ops = set(("ReduceSum",))
+elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum", "LeakyRelu", "Abs", "CLZ", "SHL", "SHR"))
 binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
 activation_ops = set(("Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh"))
 memory_only_ops = set(("Reshape",))
@@ -437,6 +439,8 @@
         npu_block_type = NpuBlockType.Pooling
     elif op.type in elementwise_op:
         npu_block_type = NpuBlockType.ElementWise
+    elif op.type in reduce_sum_ops:
+        npu_block_type = NpuBlockType.ReduceSum
 
     op.attrs["npu_block_type"] = npu_block_type
     return op
@@ -573,6 +577,13 @@
     return op
 
 
+def convert_softmax(op, arch):
+    if op.type == "Softmax" and op.run_on_npu:
+        softmax = SoftMax(op)
+        op = softmax.get_graph()
+    return op
+
+
 def convert_mul_max_to_abs_or_lrelu(op, arch):
     r"""Whenever there is a subgraph with this topology:
 
@@ -671,6 +682,7 @@
         # then do any rewrites of supported operators
         convert_depthwise_to_conv,
         convert_conv_to_fc,
+        convert_softmax,
         fixup_fully_connected_input,
         fixup_pack_input,
         fixup_conv2d_backprop,
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 0053e79..c669829 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -56,8 +56,10 @@
                 new_start_coord[idx] += split_offset[idx]
                 new_end_coord[idx] += split_offset[idx]
 
-        if split_offset is None and npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
-            # these types of operations do a "dot product" over the entire IFM
+        if split_offset is None and npu_block_type in set(
+            (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
+        ):
+            # these types of operations do a "dot product" or sum over the entire IFM
             new_start_coord[-1] = 0
             new_end_coord[-1] = ifm_shape[-1]
 
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 6aa88d8..2297a3b 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -143,18 +143,21 @@
                 if (
                     intermediate is not None
                     and intermediate.shape != []
-                    and intermediate.purpose == TensorPurpose.FeatureMap
+                    and intermediate.purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT)
                 ):
-                    intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                        strides,
-                        skirt,
-                        intermediate.shape,
-                        npu_block_type,
-                        concat_axis,
-                        concat_offset,
-                        split_offsets[0],
-                        upscaling,
-                    )
+                    if intermediate.purpose is TensorPurpose.FeatureMap:
+                        intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
+                            strides,
+                            skirt,
+                            intermediate.shape,
+                            npu_block_type,
+                            concat_axis,
+                            concat_offset,
+                            split_offsets[0],
+                            upscaling,
+                        )
+                    else:
+                        intermediate_box = Box([0] * len(intermediate.shape), list(intermediate.shape))
                     yield from dma_if_necessary(ps, intermediate_box, intermediate)
 
             weight_box = None
@@ -232,7 +235,7 @@
             ofm_box = Box(ofm_start, ofm_end)
 
             k_height = 1
-            if npu_block_type == NpuBlockType.Pooling:
+            if npu_block_type in set((NpuBlockType.Pooling, NpuBlockType.ReduceSum)):
                 if ps.primary_op is not None:
                     k_height = ps.primary_op.attrs["ksize"][1]
             else:
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index 4ea4621..76016f1 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -35,10 +35,11 @@
         if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
             # Tensor is in permanent storage
             # Only when permanent storage differs from fast storage, there is a point moving the data
-            if tens.mem_area in (MemArea.Dram, MemArea.OffChipFlash) and (
-                arch.permanent_storage_mem_area != arch.fast_storage_mem_area
-            ):
-                if tens.purpose == TensorPurpose.Weights or (
+            if (
+                tens.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
+                and arch.permanent_storage_mem_area != arch.fast_storage_mem_area
+            ) or tens.purpose == TensorPurpose.LUT:
+                if tens.purpose in (TensorPurpose.Weights, TensorPurpose.LUT) or (
                     tens.purpose == TensorPurpose.FeatureMap and op.type in binary_elementwise_op and tens.shape != []
                 ):
                     only_vector_product_consumers = True
@@ -49,7 +50,8 @@
 
                     # Tensor products has no need for DMA, tensors are only read once and can be in flash.
                     # Other operations re-reads tensors, this is better done from SRAM.
-                    if not only_vector_product_consumers:
+                    # LUTs must be placed in the last 2 blocks of SHRAM.
+                    if not only_vector_product_consumers or tens.purpose == TensorPurpose.LUT:
                         # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
                         new_tens = tens.clone_into_fast_storage(arch)
                         dma_cmd = Operation("DMA", tens.ops[0].name + "_dma")
@@ -59,6 +61,14 @@
                         dma_cmd.attrs["destination"] = new_tens.mem_area
                         dma_cmd.run_on_npu = True
                         new_tens.ops = [dma_cmd]
+                        if tens.purpose == TensorPurpose.LUT:
+                            # TODO: Add support for more than one LUT at a time
+                            # Reserve last 2 blocks for LUT
+                            if arch.shram_reserved_unused_banks == 0:
+                                arch.shram_reserved_unused_banks = 2
+                                arch.shram_total_banks -= arch.shram_reserved_unused_banks
+                            # Place the LUT in the last 2 blocks of SHRAM
+                            new_tens.address = arch.shram_bank_size * arch.shram_total_banks
                         op.inputs[idx] = new_tens
     return op
 
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 5f3a13f..b6b2f9f 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -96,6 +96,10 @@
                 "AvgPoolAct",
                 "MaxPoolAct",
                 "LeakyRelu",
+                "CLZ",
+                "SHL",
+                "SHR",
+                "ReduceSum",
             )
         ),
         all_fm,
@@ -252,7 +256,7 @@
 
         if tens.purpose == TensorPurpose.Unknown or tens.purpose == purpose:
             tens.purpose = purpose
-        else:
+        elif tens.purpose != TensorPurpose.LUT:
             assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
         tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]
         tens.mem_type = arch.tensor_storage_mem_type[tens.purpose]
@@ -332,7 +336,7 @@
     formats_for_tensor = {}
 
     def init_tens(tens):
-        if tens.purpose == TensorPurpose.FeatureMap:
+        if tens.purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT):
             fmt = arch.default_feature_map_format
         elif tens.purpose == TensorPurpose.Weights:
             fmt = arch.default_weight_format
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 1024307..4a2855b 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -51,6 +51,7 @@
         self.ofm_tensor = None
         self.weight_tensor = None
         self.scale_tensor = None
+        self.lut_tensor = None
         self.name = name
         self.cascade = None
         self.placement = placement
@@ -85,6 +86,11 @@
             return None, None, None, None
         return self.primary_op.get_ifm_weights_biases_ofm()
 
+    def get_primary_op_lut(self):
+        if not self.primary_op:
+            return None
+        return self.primary_op.activation_lut
+
 
 class SchedulingStrategy(enum.Enum):
     Unknown = -1
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 4b5a888..030503d 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -130,6 +130,8 @@
 
                     copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.scale_tensor)
 
+                if ps.lut_tensor is not None:
+                    copy_ifm_values_to_memory_tensor(sg.flash_tensor, ps.lut_tensor)
                 if ps.ifm_tensor is not None and ps.ifm_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
                     copy_ifm_values_to_memory_tensor(sg.flash_tensor, ps.ifm_tensor)
                 if ps.ifm2_tensor is not None and (
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 448d838..7134fd8 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -25,13 +25,25 @@
     Pooling = 3
     ConvolutionDepthWise = 4
     ElementWise = 5
+    ReduceSum = 6
 
 
 class Operation:
     """Class representing a Neural Network operation. Has a name, a type,
 input and output tensors, as well as an attribute dictionary."""
 
-    __slots__ = "type", "name", "op_index", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"
+    __slots__ = (
+        "type",
+        "name",
+        "op_index",
+        "attrs",
+        "inputs",
+        "outputs",
+        "flops",
+        "scheduled_pass",
+        "run_on_npu",
+        "activation_lut",
+    )
 
     def __init__(self, op_type, name):
         self.type = op_type
@@ -43,6 +55,7 @@
         self.run_on_npu = True
         self.scheduled_pass = None
         self.op_index = None  # input network operator index
+        self.activation_lut = None
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
@@ -80,7 +93,7 @@
             elif self.type == "Conv2DBackpropInputSwitchedBias":
                 bias_idx = 3
 
-        elif npu_block_type == NpuBlockType.Pooling:
+        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
             ifm_idx = 0
             ofm_idx = 0
         elif npu_block_type == NpuBlockType.VectorProduct:
@@ -102,8 +115,8 @@
             ifm2_idx = 1
             ofm_idx = 0
 
-            # LeakyRelu and Abs have a single IFM
-            if self.type in set(("LeakyRelu", "Abs")):
+            # LeakyRelu, Abs and CLZ have a single IFM
+            if self.type in set(("LeakyRelu", "Abs", "CLZ")):
                 ifm2_idx = -1
 
         elif self.type == "Conv2DBackpropInput":
@@ -292,3 +305,9 @@
             assert False
 
         return input_tens, outputs, axis, offset_start, offset_end
+
+    def set_activation_lut(self, lut_tensor):
+        lut_tensor.consumer_list.append(self)
+        self.attrs["fused_activation_function"] = "LUT"
+        self.activation_lut = lut_tensor
+        self.inputs.append(lut_tensor)
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index c14a70b..8fb95f0 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -66,6 +66,7 @@
         "MaxPool",
         "AvgPoolAct",
         "MaxPoolAct",
+        "ReduceSum",
         # deconvolution
         "ResizeBilinear",
     )
@@ -85,10 +86,12 @@
         "Sub",
         "Minimum",
         "Maximum",
+        "SHL",
+        "SHR",
     )
 )
 
-unary_elem_wise_main_ops = set(("LeakyRelu", "Abs"))  # Unary element-wise operations
+unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))  # Unary element-wise operations
 
 elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
 
@@ -417,13 +420,12 @@
             # Swap broadcast input if applicable
             broadcast_input_check(ps)
 
+            # If only 1 input, IFM and IFM2 will be the same tensor
             ps.ifm_tensor = ps.inputs[0]
+            ps.ifm2_tensor = ps.inputs[-1]
 
-            if len(ps.inputs) == 1:
-                # Only 1 input, IFM and IFM2 are the same tensor
-                ps.ifm2_tensor = ps.inputs[0]
-            else:
-                ps.ifm2_tensor = ps.inputs[1]
+            if len(ps.inputs) > 2:
+                ps.ifm_tensor = ps.inputs[-2]
         else:
             ps.ifm_tensor = ifm_tensor
             ps.ifm2_tensor = None
@@ -432,6 +434,7 @@
         assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None
         ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
         ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2]
+        ps.lut_tensor = ps.get_primary_op_lut()
 
         for op in ps.ops:
             op.scheduled_pass = ps
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 38b40ba..e0f114e 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -38,6 +38,7 @@
 from .ethos_u55_regs.ethos_u55_regs import cmd1
 from .ethos_u55_regs.ethos_u55_regs import elementwise_mode
 from .ethos_u55_regs.ethos_u55_regs import ifm_precision
+from .ethos_u55_regs.ethos_u55_regs import pooling_mode
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .ethos_u55_regs.ethos_u55_regs import rounding
 from .high_level_command_stream import CommandType
@@ -52,6 +53,7 @@
 from .tensor import MemType
 from .tensor import TensorBlockTraversal
 from .tensor import TensorFormat
+from .tensor import TensorPurpose
 
 
 class RegisterMachine:
@@ -81,6 +83,7 @@
     WeightTensor = 0  # base address index for the Weight tensor
     ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
     ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
+    Mem2Mem = (1 << 8) | (3 << 0)  # base address slot for memory-to-memory transfer
 
 
 # TODO: Replace with definitions from ethos_u55_regs
@@ -220,7 +223,9 @@
                         elif memory_accesses[prev_cmd].conflicts(curr_accesses):
                             is_dependency = True
                     else:
-                        if memory_accesses[prev_cmd].conflicts(curr_accesses):
+                        if memory_accesses[prev_cmd].conflicts(curr_accesses) or (
+                            prev_cmd.cmdtype == CommandType.DMA and prev_cmd.in_tensor.purpose == TensorPurpose.LUT
+                        ):
                             is_dependency = True
 
                     if is_dependency:
@@ -295,7 +300,12 @@
     # Note: NOT equivalent to the normal ifm block depth calculation since
     # it takes into account 'depthless' block operations by returning full
     # depth
-    if cmd.ps.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling, NpuBlockType.ElementWise):
+    if cmd.ps.npu_block_type in (
+        NpuBlockType.ConvolutionDepthWise,
+        NpuBlockType.Pooling,
+        NpuBlockType.ElementWise,
+        NpuBlockType.ReduceSum,
+    ):
         return cmd.ofm_box.get_size_shape()[-1]
 
     return arch.calc_ifm_block_depth(cmd.ifm_box.get_size_shape()[-1], cmd.ifm_tensor.dtype.bits)
@@ -306,6 +316,7 @@
         NpuBlockType.ConvolutionDepthWise,
         NpuBlockType.Pooling,
         NpuBlockType.ConvolutionMxN,
+        NpuBlockType.ReduceSum,
     ):
         return (0, 0)
 
@@ -353,6 +364,9 @@
         "Maximum": elementwise_mode.MAX.value,
         "LeakyRelu": elementwise_mode.LRELU.value,
         "Abs": elementwise_mode.ABS.value,
+        "CLZ": elementwise_mode.CLZ.value,
+        "SHR": elementwise_mode.SHR.value,
+        "SHL": elementwise_mode.SHL.value,
     }
 
     cmd_stream = []
@@ -407,7 +421,10 @@
 
             emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, base_ptr_idx_map[cmd.in_tensor.mem_type])
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_SRC, src_addr)
-            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_type])
+            if cmd.out_tensor.purpose == TensorPurpose.LUT:
+                emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, BasePointerIndex.Mem2Mem)
+            else:
+                emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_type])
 
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_DST, dst_addr)
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, sz)
@@ -451,7 +468,9 @@
             shared_buffer = ps.shared_buffer
 
             if npu_block_type == NpuBlockType.ElementWise:
-                ifm2_broadcast = 0
+                ifm2_broadcast = (
+                    IFM2Broadcast.ReverseOperandOrder if primary_op.attrs.get("reverse_op_order", False) else 0
+                )
 
                 if cmd.ifm_tensor.shape == []:
                     # The scalar has to be the ifm2 tensor so switch the ifms
@@ -468,22 +487,26 @@
                     output_scale = cmd.ofm_tensor.quantization.scale_f32
                     use_global_scale = True
 
-                    if primary_op.type == "MulAct":
-                        if (faf == "Sigmoid") or (faf == "Tanh"):
-                            output_scale = 1 / 0x3000
+                    if output_scale is not None and faf in ("Sigmoid", "Tanh"):
+                        output_scale = 1 / 0x3000
 
-                        ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
+                    if primary_op.type == "MulAct":
+                        if None in (input_scale, input2_scale, output_scale):
+                            ofm_scale = 1
+                            shift = 0
+                        else:
+                            ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
                     else:  # AddAct/SubAct
-                        if (faf == "Sigmoid") or (faf == "Tanh"):
-                            output_scale = 1 / 0x3000
-
                         # Force output scale same as the input scale for
                         # resizebiliner 1x1 that is converted to add
                         if "resizebilinear" in primary_op.attrs:
                             output_scale = input2_scale
 
-                        if input_scale == input2_scale:
+                        if None in (input_scale, input2_scale, output_scale):
+                            opa_scale = opb_scale = ofm_scale = 1
+                            opa_shift = shift = 0
+                        elif input_scale == input2_scale:
                             opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
                                 input_scale, input2_scale, output_scale
                             )
@@ -512,7 +535,7 @@
                         emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale)
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
 
-                if primary_op.type in set(("LeakyRelu", "Abs",)):
+                elif primary_op.type in set(("LeakyRelu", "Abs",)):
                     output_scale = cmd.ofm_tensor.quantization.scale_f32
                     use_global_scale = True
 
@@ -521,6 +544,8 @@
 
                     ofm_scale, shift = scaling.quantise_scale(output_scale)
                     emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
+                else:
+                    emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0)
 
                 # For elementwise set the required SHRAM to be equal to the total size of SHRAM
                 shram_required = arch.shram_total_banks
@@ -581,7 +606,12 @@
                 emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.NONE)
 
             if npu_block_type in set(
-                (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
+                (
+                    NpuBlockType.ConvolutionMxN,
+                    NpuBlockType.ConvolutionDepthWise,
+                    NpuBlockType.Pooling,
+                    NpuBlockType.ReduceSum,
+                )
             ):
                 # Set up padding
                 explicit_padding = list(primary_op.attrs["explicit_padding"])  # (top, left, bottom, right)
@@ -611,14 +641,17 @@
                 # set kernel y stride extension bits
                 stride |= (primary_op.attrs["strides"][1] - 1 >> 1) << 9
 
-                if npu_block_type == NpuBlockType.Pooling:
+                if npu_block_type in set((NpuBlockType.Pooling, NpuBlockType.ReduceSum)):
                     k_height, k_width = primary_op.attrs["ksize"][1:3]
                     emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, k_height - 1)
                     emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, k_width - 1)
 
                     valid_padding = sum(explicit_padding) == 0
 
-                    if primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear")) and valid_padding:
+                    if (
+                        primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "ReduceSum"))
+                        and valid_padding
+                    ):
                         # For valid padding vela has to output scaling values
                         if faf == "Sigmoid" or faf == "Tanh":
                             rescale = 0x3000 * cmd.ifm_tensor.quantization.scale_f32
@@ -644,17 +677,24 @@
                             # k_height == k_width == 1 is allways true in this case
                             # Normally the scale is maximised, to get maximum precision, which means that
                             # if rescale != 1, scale need to consider the number of bits needed for rescaling
-                            rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
-                            rescale_bits = 0
-                            if k_height == k_width == 1:
-                                if fmf == "ConcatSliceWrite":
-                                    rounding_mode = rounding.NATURAL
-                                if rescale > 1:
-                                    rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
-                                elif rescale < 1:
-                                    rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
-                            scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
-                            scale = int(round_away_zero(scale * rescale))
+                            if None not in (
+                                cmd.ofm_tensor.quantization.scale_f32,
+                                cmd.ifm_tensor.quantization.scale_f32,
+                            ):
+                                rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
+                                rescale_bits = 0
+                                if k_height == k_width == 1:
+                                    if fmf == "ConcatSliceWrite":
+                                        rounding_mode = rounding.NATURAL
+                                    if rescale > 1:
+                                        rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
+                                    elif rescale < 1:
+                                        rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
+                                scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
+                                scale = int(round_away_zero(scale * rescale))
+                            else:
+                                scale = 1
+                                shift = 0
 
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, scale, shift)
                         # Valid-padded average pool should use the global scale from
@@ -798,6 +838,12 @@
                 else:
                     faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                     faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
+            elif faf == "LUT":
+                lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", 0)
+                assert lut_index <= activation.LUT_END.value, "LUT index out of range."
+                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index)
+                faf_min = ofm_quant_qmin
+                faf_max = ofm_quant_qmax
             else:
                 raise Exception("Unsupported fused_activation_function = " + faf)
 
@@ -816,7 +862,7 @@
                 emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, 0)
             emit.cmd0_with_param(cmd0.NPU_SET_OFM_DEPTH_M1, out_shape[-1] - 1)
 
-            if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
+            if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)):
                 in_shape = cmd.ifm_box.get_size_shape()
                 emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, in_shape[-1] - 1)
             else:
@@ -942,6 +988,8 @@
                 prec = 0
             elif ofm_dtype.size_in_bits() == 16:
                 prec = 2
+            elif ofm_dtype.size_in_bits() == 32:
+                prec = 4
             else:
                 assert 0
 
@@ -967,7 +1015,11 @@
             ifm_dtype = cmd.ifm_tensor.dtype
 
             assert weight_bits == 8, "Unsupported weight bit depth"
-            assert ifm_dtype.size_in_bits() in {8, 16}
+            assert (
+                ifm_dtype.size_in_bits() in {8, 16}
+                or ifm_dtype.size_in_bits() == 32
+                and npu_block_type in (NpuBlockType.ElementWise, NpuBlockType.ReduceSum)
+            ), "Unsupported ifm bit depth"
 
             if ifm_dtype.size_in_bits() == 8:
                 if ifm_dtype.type & BaseType.Signed:
@@ -979,6 +1031,8 @@
                     prec = ifm_precision.S16
                 else:
                     prec = ifm_precision.U16
+            elif ifm_dtype == DataType.int32:
+                prec = ifm_precision.S32
 
             ifm_prec = prec.value
             ifm2_prec = ifm_prec
@@ -1036,8 +1090,10 @@
                 # Vector product is implemented using a 1x1 convolution
                 emit.cmd_do_operation(cmd0.NPU_OP_CONV)
             elif npu_block_type == NpuBlockType.Pooling:
-                param = "Max" not in primary_op.type
+                param = pooling_mode.MAX.value if "Max" in primary_op.type else pooling_mode.AVERAGE.value
                 emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=param)
+            elif npu_block_type == NpuBlockType.ReduceSum:
+                emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_mode.REDUCE_SUM.value)
             elif npu_block_type == NpuBlockType.ElementWise:
                 param = elementwise_mode_map[primary_op.type]
                 emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param)
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 07637f3..7268d9f 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -82,6 +82,8 @@
                 assert (self.use_ifm_element == SHRAMElements.IFM16) or (
                     self.use_ifm_element == SHRAMElements.IFM16_Elementwise
                 )
+            elif is_elementwise and self.ifm_bits == 32:
+                self.use_ifm_element = SHRAMElements.IFM32_Elementwise
             else:
                 assert self.ifm_bits == 8, "Unexpected IFM bitdepth"
 
@@ -168,7 +170,7 @@
     if arch.override_block_config:
         config = alloc.try_block(arch.override_block_config)
         if config is None:
-            raise VelaError("Block config override '{0}' cannot be allocated".format(arch.override_block_config) )
+            raise VelaError("Block config override '{0}' cannot be allocated".format(arch.override_block_config))
         return [config]
 
     # Constrain the search space if the OFM is smaller than the max block size
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
new file mode 100644
index 0000000..000c78e
--- /dev/null
+++ b/ethosu/vela/softmax.py
@@ -0,0 +1,421 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Contains SoftMax
+import numpy as np
+
+from . import scaling
+from .data_type import DataType
+from .operation import Operation
+from .tensor import Tensor
+from .tensor import TensorPurpose
+
+
+class TensorUtil:
+    # TODO: Move these functions to Tensor/Operation classes
+    @staticmethod
+    def create_const_tensor(
+        name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None
+    ):
+        const_op = Operation("Const", name)
+        const_tensor = Tensor(shape, dtype, name + "_0")
+        const_tensor.purpose = purpose
+        const_tensor.quantization = quantization
+        const_tensor.values = np.array(values, dtype=value_dtype)
+        const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
+        const_tensor.ops.append(const_op)
+        const_op.outputs.append(const_tensor)
+        return const_tensor
+
+    @staticmethod
+    def add_ifm_tensor(op, tens):
+        op.inputs.append(tens)
+        tens.consumer_list.append(op)
+
+    @staticmethod
+    def set_ofm_tensor(op, tens):
+        tens.ops = [op]
+        op.outputs = [tens]
+
+    @staticmethod
+    def reshape(tens, shape, ifm_reshape=True):
+        if shape == tens.shape:
+            return tens
+        name = tens.name + "_reshape"
+        reshape_op = Operation("Reshape", name)
+        reshape_op.attrs["new_shape"] = shape
+        reshape_ifm = tens
+        reshape_ofm = tens.clone("_reshaped")
+        reshape_ofm.shape = reshape_ofm.storage_shape = reshape_ofm.bandwidth_shape = shape
+        if not ifm_reshape:
+            reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
+        reshape_op.inputs = [reshape_ifm, TensorUtil.create_const_tensor(name + "_shape", [1], DataType.int32, shape)]
+        TensorUtil.set_ofm_tensor(reshape_op, reshape_ofm)
+        return reshape_ofm if ifm_reshape else reshape_ifm
+
+    @staticmethod
+    def get_full_shape(shape):
+        d = len(shape)
+        if d in (1, 3):
+            return [1] * (4 - d) + shape
+        elif d == 2:
+            return [shape[0], 1, 1, shape[1]]
+        else:
+            return shape
+
+
+class SoftMax:
+    # Turn off black formatting for the LUT tables to keep them compact
+    # fmt: off
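+    # Note: each 32-bit entry appears to pack a 16-bit base value in the low half
+    # and a 16-bit slope in the high half (e.g. 0x00010002 = base 2, slope 1),
+    # allowing the hardware to interpolate between consecutive entries.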
+    EXP_LUT = [
+        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
+        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
+        0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
+        0x00000002, 0x00000002, 0x00010002, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
+        0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003, 0x00000003,
+        0x00000003, 0x00000003, 0x00000003, 0x00010003, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
+        0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004, 0x00000004,
+        0x00010004, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005, 0x00000005,
+        0x00000005, 0x00000005, 0x00010005, 0x00000006, 0x00000006, 0x00000006, 0x00000006, 0x00000006,
+        0x00000006, 0x00000006, 0x00010006, 0x00000007, 0x00000007, 0x00000007, 0x00000007, 0x00000007,
+        0x00000007, 0x00000007, 0x00010007, 0x00000008, 0x00000008, 0x00000008, 0x00000008, 0x00000008,
+        0x00010008, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00000009, 0x00010009, 0x0000000a,
+        0x0000000a, 0x0000000a, 0x0000000a, 0x0001000a, 0x0000000b, 0x0000000b, 0x0000000b, 0x0000000b,
+        0x0001000b, 0x0000000c, 0x0000000c, 0x0000000c, 0x0001000c, 0x0000000d, 0x0000000d, 0x0000000d,
+        0x0001000d, 0x0000000e, 0x0000000e, 0x0000000e, 0x0001000e, 0x0000000f, 0x0000000f, 0x0001000f,
+        0x00000010, 0x00000010, 0x00010010, 0x00000011, 0x00000011, 0x00010011, 0x00000012, 0x00000012,
+        0x00010012, 0x00000013, 0x00000013, 0x00010013, 0x00000014, 0x00010014, 0x00000015, 0x00000015,
+        0x00010015, 0x00000016, 0x00010016, 0x00000017, 0x00010017, 0x00000018, 0x00010018, 0x00000019,
+        0x00010019, 0x0000001a, 0x0001001a, 0x0000001b, 0x0001001b, 0x0000001c, 0x0001001c, 0x0000001d,
+        0x0001001d, 0x0000001e, 0x0001001e, 0x0001001f, 0x00000020, 0x00010020, 0x00010021, 0x00000022,
+        0x00010022, 0x00010023, 0x00000024, 0x00010024, 0x00000025, 0x00010025, 0x00010026, 0x00010027,
+        0x00000028, 0x00020028, 0x0000002a, 0x0001002a, 0x0001002b, 0x0001002c, 0x0000002d, 0x0001002d,
+        0x0001002e, 0x0001002f, 0x00010030, 0x00010031, 0x00010032, 0x00010033, 0x00010034, 0x00010035,
+        0x00010036, 0x00010037, 0x00010038, 0x00020039, 0x0001003b, 0x0000003c, 0x0002003c, 0x0001003e,
+        0x0002003f, 0x00000041, 0x00020041, 0x00010043, 0x00010044, 0x00020045, 0x00020047, 0x00010049,
+        0x0001004a, 0x0002004b, 0x0001004d, 0x0002004e, 0x00010050, 0x00020051, 0x00020053, 0x00010055,
+        0x00020056, 0x00020058, 0x0002005a, 0x0001005c, 0x0002005d, 0x0002005f, 0x00020061, 0x00020063,
+        0x00020065, 0x00020067, 0x00020069, 0x0002006b, 0x0003006d, 0x00020070, 0x00020072, 0x00020074,
+        0x00030076, 0x00020079, 0x0003007b, 0x0002007e, 0x00030080, 0x00020083, 0x00020085, 0x00040087,
+        0x0002008b, 0x0003008d, 0x00030090, 0x00020093, 0x00030095, 0x00030098, 0x0003009b, 0x0004009e,
+        0x000300a2, 0x000300a5, 0x000300a8, 0x000300ab, 0x000400ae, 0x000300b2, 0x000400b5, 0x000400b9,
+        0x000300bd, 0x000400c0, 0x000400c4, 0x000400c8, 0x000400cc, 0x000400d0, 0x000500d4, 0x000400d9,
+        0x000400dd, 0x000500e1, 0x000400e6, 0x000500ea, 0x000400ef, 0x000500f3, 0x000500f8, 0x000500fd,
+        0x00050102, 0x00050107, 0x0005010c, 0x00060111, 0x00050117, 0x0006011c, 0x00060122, 0x00060128,
+        0x0006012e, 0x00060134, 0x0006013a, 0x00070140, 0x00060147, 0x0007014d, 0x00060154, 0x0007015a,
+        0x00070161, 0x00060168, 0x0008016e, 0x00070176, 0x0008017d, 0x00080185, 0x0007018d, 0x00090194,
+        0x0008019d, 0x000801a5, 0x000801ad, 0x000901b5, 0x000901be, 0x000901c7, 0x000901d0, 0x000901d9,
+        0x000a01e2, 0x000901ec, 0x000a01f5, 0x000b01ff, 0x000a020a, 0x000b0214, 0x000a021f, 0x000b0229,
+        0x000b0234, 0x000b023f, 0x000c024a, 0x000c0256, 0x000c0262, 0x000c026e, 0x000c027a, 0x000d0286,
+        0x000d0293, 0x000d02a0, 0x000e02ad, 0x000e02bb, 0x000e02c9, 0x000e02d7, 0x000f02e5, 0x000f02f4,
+        0x000f0303, 0x000f0312, 0x00100321, 0x00100331, 0x00110341, 0x00100352, 0x00120362, 0x00110374,
+        0x00120385, 0x00120397, 0x001203a9, 0x001303bb, 0x001303ce, 0x001403e1, 0x001403f5, 0x00140409,
+        0x0015041d, 0x00150432, 0x00160447, 0x0016045d, 0x00160473, 0x00170489, 0x001704a0, 0x001904b7,
+        0x001804d0, 0x001904e8, 0x00190501, 0x001a051a, 0x001a0534, 0x001b054e, 0x001b0569, 0x001c0584,
+        0x001c05a0, 0x001d05bc, 0x001e05d9, 0x001e05f7, 0x001e0615, 0x00200633, 0x00200653, 0x00200673,
+        0x00210693, 0x002206b4, 0x002306d6, 0x002306f9, 0x0024071c, 0x00240740, 0x00260764, 0x0026078a,
+        0x002607b0, 0x002807d6, 0x002907fe, 0x00290827, 0x002a0850, 0x002a087a, 0x002c08a4, 0x002c08d0,
+        0x002e08fc, 0x002e092a, 0x002f0958, 0x00310987, 0x003109b8, 0x003209e9, 0x00330a1b, 0x00340a4e,
+        0x00350a82, 0x00350ab7, 0x00380aec, 0x00380b24, 0x003a0b5c, 0x003a0b96, 0x003c0bd0, 0x003d0c0c,
+        0x003e0c49, 0x003f0c87, 0x00400cc6, 0x00420d06, 0x00430d48, 0x00440d8b, 0x00460dcf, 0x00480e15,
+        0x00480e5d, 0x00490ea5, 0x004c0eee, 0x004d0f3a, 0x004e0f87, 0x00500fd5, 0x00511025, 0x00531076,
+        0x005610c9, 0x0056111f, 0x00581175, 0x005a11cd, 0x005c1227, 0x005e1283, 0x005e12e1, 0x0061133f,
+        0x006413a0, 0x00651404, 0x00671469, 0x006914d0, 0x006c1539, 0x006c15a5, 0x00701611, 0x00721681,
+        0x007416f3, 0x00761767, 0x007917dd, 0x007a1856, 0x007d18d0, 0x0080194d, 0x008319cd, 0x00841a50,
+        0x00881ad4, 0x00891b5c, 0x008d1be5, 0x00911c72, 0x00911d03, 0x00961d94, 0x00981e2a, 0x009c1ec2,
+        0x009e1f5e, 0x00a21ffc, 0x00a4209e, 0x00a92142, 0x00ab21eb, 0x00ae2296, 0x00b22344, 0x00b523f6,
+        0x00b924ab, 0x00be2564, 0x00c02622, 0x00c526e2, 0x00c827a7, 0x00cc286f, 0x00d0293b, 0x00d52a0b,
+        0x00d72ae0, 0x00dd2bb7, 0x00e12c94, 0x00e62d75, 0x00eb2e5b, 0x00ef2f46, 0x00f23035, 0x00f83127,
+        0x00fe321f, 0x0101331d, 0x0108341e, 0x010c3526, 0x01123632, 0x01173744, 0x011c385b, 0x01233977,
+        0x01273a9a, 0x012e3bc1, 0x01343cef, 0x013a3e23, 0x01403f5d, 0x0146409d, 0x014c41e3, 0x0154432f,
+        0x01594483, 0x016145dc, 0x0168473d, 0x016f48a5, 0x01764a14, 0x017d4b8a, 0x01854d07, 0x018d4e8c,
+        0x01945019, 0x019d51ad, 0x01a4534a, 0x01ad54ee, 0x01b5569b, 0x01be5850, 0x01c75a0e, 0x01d05bd5,
+        0x01d85da5, 0x01e35f7d, 0x01eb6160, 0x01f6634b, 0x01ff6541, 0x02096740, 0x02146949, 0x021e6b5d,
+        0x02296d7b, 0x02336fa4, 0x023f71d7, 0x024a7416, 0x02567660, 0x026278b6, 0x026d7b18, 0x027a7d85,
+    ]
+
+    ONE_OVER_ONE_PLUS_X_LUT = [
+        0xffc17fff, 0xffc07fc0, 0xffc27f80, 0xffc07f42, 0xffc17f02, 0xffc17ec3, 0xffc27e84, 0xffc27e46,
+        0xffc27e08, 0xffc37dca, 0xffc27d8d, 0xffc37d4f, 0xffc37d12, 0xffc37cd5, 0xffc37c98, 0xffc47c5b,
+        0xffc47c1f, 0xffc47be3, 0xffc57ba7, 0xffc57b6c, 0xffc37b31, 0xffc67af4, 0xffc57aba, 0xffc67a7f,
+        0xffc57a45, 0xffc67a0a, 0xffc779d0, 0xffc67997, 0xffc6795d, 0xffc77923, 0xffc778ea, 0xffc778b1,
+        0xffc87878, 0xffc77840, 0xffc87807, 0xffc877cf, 0xffc97797, 0xffc87760, 0xffc97728, 0xffc976f1,
+        0xffc976ba, 0xffc87683, 0xffca764b, 0xffca7615, 0xffca75df, 0xffca75a9, 0xffca7573, 0xffcb753d,
+        0xffca7508, 0xffcb74d2, 0xffcb749d, 0xffca7468, 0xffcc7432, 0xffcc73fe, 0xffcb73ca, 0xffcc7395,
+        0xffcd7361, 0xffcc732e, 0xffcc72fa, 0xffcd72c6, 0xffcd7293, 0xffcd7260, 0xffcc722d, 0xffce71f9,
+        0xffcd71c7, 0xffce7194, 0xffce7162, 0xffce7130, 0xffcf70fe, 0xffce70cd, 0xffce709b, 0xffcf7069,
+        0xffcf7038, 0xffcf7007, 0xffcf6fd6, 0xffcf6fa5, 0xffd06f74, 0xffd06f44, 0xffd06f14, 0xffd06ee4,
+        0xffd06eb4, 0xffd06e84, 0xffd16e54, 0xffd16e25, 0xffd16df6, 0xffd16dc7, 0xffd06d98, 0xffd26d68,
+        0xffd16d3a, 0xffd26d0b, 0xffd26cdd, 0xffd26caf, 0xffd26c81, 0xffd26c53, 0xffd36c25, 0xffd26bf8,
+        0xffd36bca, 0xffd36b9d, 0xffd36b70, 0xffd26b43, 0xffd46b15, 0xffd36ae9, 0xffd46abc, 0xffd46a90,
+        0xffd46a64, 0xffd46a38, 0xffd46a0c, 0xffd469e0, 0xffd469b4, 0xffd56988, 0xffd5695d, 0xffd56932,
+        0xffd56907, 0xffd568dc, 0xffd568b1, 0xffd56886, 0xffd6685b, 0xffd56831, 0xffd66806, 0xffd667dc,
+        0xffd667b2, 0xffd76788, 0xffd6675f, 0xffd76735, 0xffd6670c, 0xffd766e2, 0xffd666b9, 0xffd7668f,
+        0xffd86666, 0xffd6663e, 0xffd86614, 0xffd765ec, 0xffd865c3, 0xffd8659b, 0xffd86573, 0xffd8654b,
+        0xffd86523, 0xffd864fb, 0xffd964d3, 0xffd864ac, 0xffd96484, 0xffd8645d, 0xffd96435, 0xffd9640e,
+        0xffd963e7, 0xffd963c0, 0xffd96399, 0xffda6372, 0xffd9634c, 0xffda6325, 0xffda62ff, 0xffda62d9,
+        0xffda62b3, 0xffda628d, 0xffda6267, 0xffdb6241, 0xffda621c, 0xffdb61f6, 0xffda61d1, 0xffdc61ab,
+        0xffd96187, 0xffdc6160, 0xffdb613c, 0xffdb6117, 0xffdb60f2, 0xffdc60cd, 0xffdc60a9, 0xffdb6085,
+        0xffdc6060, 0xffdc603c, 0xffdc6018, 0xffdc5ff4, 0xffdc5fd0, 0xffdd5fac, 0xffdc5f89, 0xffdc5f65,
+        0xffdd5f41, 0xffdd5f1e, 0xffdd5efb, 0xffdd5ed8, 0xffdd5eb5, 0xffdd5e92, 0xffdd5e6f, 0xffdd5e4c,
+        0xffdd5e29, 0xffde5e06, 0xffde5de4, 0xffdd5dc2, 0xffde5d9f, 0xffde5d7d, 0xffde5d5b, 0xffde5d39,
+        0xffdf5d17, 0xffde5cf6, 0xffde5cd4, 0xffdf5cb2, 0xffdf5c91, 0xffde5c70, 0xffdf5c4e, 0xffdf5c2d,
+        0xffde5c0c, 0xffe05bea, 0xffdf5bca, 0xffdf5ba9, 0xffdf5b88, 0xffdf5b67, 0xffe05b46, 0xffe05b26,
+        0xffdf5b06, 0xffe05ae5, 0xffe05ac5, 0xffe05aa5, 0xffe05a85, 0xffe05a65, 0xffe05a45, 0xffe15a25,
+        0xffe05a06, 0xffe059e6, 0xffe159c6, 0xffe159a7, 0xffe05988, 0xffe15968, 0xffe15949, 0xffe1592a,
+        0xffe1590b, 0xffe158ec, 0xffe258cd, 0xffe158af, 0xffe15890, 0xffe25871, 0xffe15853, 0xffe25834,
+        0xffe25816, 0xffe257f8, 0xffe157da, 0xffe257bb, 0xffe3579d, 0xffe25780, 0xffe25762, 0xffe25744,
+        0xffe35726, 0xffe25709, 0xffe256eb, 0xffe356cd, 0xffe356b0, 0xffe35693, 0xffe25676, 0xffe35658,
+        0xffe3563b, 0xffe3561e, 0xffe35601, 0xffe355e4, 0xffe455c7, 0xffe355ab, 0xffe4558e, 0xffe35572,
+        0xffe45555, 0xffe35539, 0xffe4551c, 0xffe45500, 0xffe454e4, 0xffe454c8, 0xffe454ac, 0xffe45490,
+        0xffe45474, 0xffe55458, 0xffe4543d, 0xffe45421, 0xffe55405, 0xffe553ea, 0xffe453cf, 0xffe553b3,
+        0xffe45398, 0xffe5537c, 0xffe55361, 0xffe55346, 0xffe5532b, 0xffe55310, 0xffe552f5, 0xffe552da,
+        0xffe652bf, 0xffe552a5, 0xffe5528a, 0xffe6526f, 0xffe55255, 0xffe6523a, 0xffe65220, 0xffe55206,
+        0xffe651eb, 0xffe651d1, 0xffe651b7, 0xffe6519d, 0xffe65183, 0xffe65169, 0xffe7514f, 0xffe65136,
+        0xffe6511c, 0xffe75102, 0xffe650e9, 0xffe750cf, 0xffe650b6, 0xffe7509c, 0xffe75083, 0xffe6506a,
+        0xffe75050, 0xffe75037, 0xffe7501e, 0xffe75005, 0xffe74fec, 0xffe74fd3, 0xffe74fba, 0xffe74fa1,
+        0xffe84f88, 0xffe74f70, 0xffe84f57, 0xffe74f3f, 0xffe84f26, 0xffe74f0e, 0xffe84ef5, 0xffe84edd,
+        0xffe84ec5, 0xffe84ead, 0xffe74e95, 0xffe84e7c, 0xffe84e64, 0xffe94e4c, 0xffe84e35, 0xffe84e1d,
+        0xffe84e05, 0xffe94ded, 0xffe84dd6, 0xffe84dbe, 0xffe94da6, 0xffe94d8f, 0xffe84d78, 0xffe84d60,
+        0xffea4d48, 0xffe84d32, 0xffe94d1a, 0xffe94d03, 0xffe84cec, 0xffe94cd4, 0xffe94cbd, 0xffea4ca6,
+        0xffe94c90, 0xffe84c79, 0xffea4c61, 0xffe94c4b, 0xffe94c34, 0xffea4c1d, 0xffe94c07, 0xffea4bf0,
+        0xffe94bda, 0xffea4bc3, 0xffea4bad, 0xffe94b97, 0xffea4b80, 0xffea4b6a, 0xffea4b54, 0xffea4b3e,
+        0xffea4b28, 0xffea4b12, 0xffea4afc, 0xffea4ae6, 0xffea4ad0, 0xffeb4aba, 0xffea4aa5, 0xffea4a8f,
+        0xffeb4a79, 0xffea4a64, 0xffea4a4e, 0xffeb4a38, 0xffeb4a23, 0xffea4a0e, 0xffeb49f8, 0xffea49e3,
+        0xffeb49cd, 0xffeb49b8, 0xffeb49a3, 0xffeb498e, 0xffea4979, 0xffeb4963, 0xffeb494e, 0xffec4939,
+        0xffeb4925, 0xffea4910, 0xffec48fa, 0xffeb48e6, 0xffeb48d1, 0xffec48bc, 0xffeb48a8, 0xffec4893,
+        0xffeb487f, 0xffec486a, 0xffeb4856, 0xffec4841, 0xffec482d, 0xffeb4819, 0xffec4804, 0xffec47f0,
+        0xffec47dc, 0xffec47c8, 0xffec47b4, 0xffec47a0, 0xffec478c, 0xffec4778, 0xffec4764, 0xffec4750,
+        0xffec473c, 0xffed4728, 0xffec4715, 0xffec4701, 0xffed46ed, 0xffec46da, 0xffed46c6, 0xffec46b3,
+        0xffec469f, 0xffed468b, 0xffed4678, 0xffec4665, 0xffed4651, 0xffed463e, 0xffed462b, 0xffec4618,
+        0xffed4604, 0xffed45f1, 0xffed45de, 0xffed45cb, 0xffed45b8, 0xffed45a5, 0xffed4592, 0xffed457f,
+        0xffee456c, 0xffed455a, 0xffed4547, 0xffed4534, 0xffee4521, 0xffed450f, 0xffed44fc, 0xffee44e9,
+        0xffed44d7, 0xffee44c4, 0xffee44b2, 0xffed44a0, 0xffee448d, 0xffee447b, 0xffed4469, 0xffee4456,
+        0xffee4444, 0xffee4432, 0xffee4420, 0xffee440e, 0xffee43fc, 0xffee43ea, 0xffee43d8, 0xffee43c6,
+        0xffee43b4, 0xffee43a2, 0xffee4390, 0xffef437e, 0xffee436d, 0xffee435b, 0xffef4349, 0xffee4338,
+        0xffee4326, 0xffef4314, 0xffee4303, 0xffef42f1, 0xffee42e0, 0xffef42ce, 0xffee42bd, 0xffef42ab,
+        0xffef429a, 0xffee4289, 0xfff04277, 0xffee4267, 0xffef4255, 0xffef4244, 0xffef4233, 0xffef4222,
+        0xffee4211, 0xffef41ff, 0xfff041ee, 0xffef41de, 0xffef41cd, 0xffee41bc, 0xfff041aa, 0xffef419a,
+        0xffef4189, 0xffef4178, 0xfff04167, 0xffef4157, 0xffef4146, 0xfff04135, 0xffef4125, 0xfff04114,
+        0xffef4104, 0xfff040f3, 0xffef40e3, 0xfff040d2, 0xfff040c2, 0xffef40b2, 0xfff040a1, 0xfff04091,
+        0xfff04081, 0xffef4071, 0xfff04060, 0xfff04050, 0xfff04040, 0xfff04030, 0xfff04020, 0xfff04010
+    ]
+    # fmt: on
+
+    def __init__(self, op):
+        self.op = op
+
+    def get_graph(self):
+        ifm = self.op.inputs[0]
+        ofm = self.op.outputs[0]
+
+        if ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
+            return self.get_graph_int16(ifm, ofm)
+        else:
+            self.op.run_on_npu = False
+            return self.op
+
+    def get_graph_int16(self, ifm, ofm):
+        ifm = TensorUtil.reshape(ifm, TensorUtil.get_full_shape(ifm.shape))
+        ofm = TensorUtil.reshape(ofm, TensorUtil.get_full_shape(ofm.shape), False)
+        no_scale_quant = ifm.quantization.clone()
+        no_scale_quant.scale_f32 = None
+
+        # PASS 0 - Depthwise Maxpool
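+        # Reshapes the IFM to [1, H*W, C, 1] and max-pools with a 1xC kernel, giving
+        # the per-position maximum over the depth axis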
+        maxpool_op = self.op.clone("_maxpool0")
+        maxpool_op.type = "MaxPool"
+        maxpool_h = ifm.shape[1] * ifm.shape[2]
+        maxpool_w = ifm.shape[3]
+        maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
+        maxpool_op.attrs["padding"] = b"VALID"
+        maxpool_op.attrs["stride_w"] = 1
+        maxpool_op.attrs["stride_h"] = 1
+        maxpool_op.attrs["filter_width"] = maxpool_w
+        maxpool_op.attrs["filter_height"] = 1
+        maxpool_op.attrs["strides"] = [1, maxpool_op.attrs["stride_h"], maxpool_op.attrs["stride_w"], 1]
+        maxpool_op.attrs["ksize"] = [1, maxpool_op.attrs["filter_height"], maxpool_op.attrs["filter_width"], 1]
+        maxpool_op.inputs = [TensorUtil.reshape(ifm, maxpool_ifm_shape)]
+        maxpool_ofm = Tensor([1, maxpool_h, 1, 1], DataType.int16, maxpool_op.name + "_0")
+        maxpool_ofm.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(maxpool_op, maxpool_ofm)
+
+        # PASS 1 - Sub
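+        # Subtracts the max from every element so the values fed to exp are <= 0,
+        # widening the result to int32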
+        sub1_op = Operation("SubAct", self.op.name + "_sub1")
+        TensorUtil.add_ifm_tensor(sub1_op, ifm)
+        TensorUtil.add_ifm_tensor(sub1_op, TensorUtil.reshape(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
+        sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
+        sub1_ofm.quantization = ifm.quantization.clone()
+        TensorUtil.set_ofm_tensor(sub1_op, sub1_ofm)
+
+        # PASS 2 - Mul
+        beta = self.op.attrs.get("beta", 1.0)
+        mul2_out_range = 10.0 / 65535.0
+        mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
+        mul2_quant = ifm.quantization.clone()
+        mul2_quant.scale_f32 = beta
+        mul2_op = Operation("MulAct", self.op.name + "_mul2")
+        TensorUtil.add_ifm_tensor(mul2_op, sub1_ofm)
+        TensorUtil.add_ifm_tensor(
+            mul2_op,
+            TensorUtil.create_const_tensor(
+                mul2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.uint32, quantization=mul2_quant
+            ),
+        )
+        mul2_ofm = Tensor(ifm.shape, DataType.int32, mul2_op.name + "_0")
+        mul2_ofm.quantization = ofm.quantization.clone()
+        mul2_ofm.quantization.scale_f32 = mul2_out_range
+        TensorUtil.set_ofm_tensor(mul2_op, mul2_ofm)
+
+        # PASS 3 - Add+LUT(exp)
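+        # Adding 32767 re-centres the non-positive scaled difference into the signed
+        # 16-bit input range of the exp LUT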
+        add_op = Operation("AddAct", self.op.name + "_add3")
+        TensorUtil.add_ifm_tensor(add_op, mul2_ofm)
+        TensorUtil.add_ifm_tensor(
+            add_op,
+            TensorUtil.create_const_tensor(
+                add_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32767], np.uint32, quantization=no_scale_quant
+            ),
+        )
+        add_op.set_activation_lut(
+            TensorUtil.create_const_tensor(
+                add_op.name + "_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.uint32, TensorPurpose.LUT
+            )
+        )
+        exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
+        exp_ofm.quantization = mul2_ofm.quantization.clone()
+        TensorUtil.set_ofm_tensor(add_op, exp_ofm)
+
+        # PASS 4 - Reduce sum
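+        # Sums the exponentials over the depth axis into an int32 value per position
+        # using the new ReduceSum block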
+        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
+        reduce_sum_op.attrs["padding"] = b"VALID"
+        reduce_sum_op.attrs["stride_w"] = 1
+        reduce_sum_op.attrs["stride_h"] = 1
+        reduce_sum_op.attrs["filter_width"] = 1
+        reduce_sum_op.attrs["filter_height"] = 1
+        reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1]
+        reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1]
+        TensorUtil.add_ifm_tensor(reduce_sum_op, exp_ofm)
+
+        reduce_sum_shape = [1, exp_ofm.shape[1], exp_ofm.shape[2], 1]
+        sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
+        sum_of_exp.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(reduce_sum_op, sum_of_exp)
+
+        # PASS 5 - CLZ
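+        # Count leading zeros of each sum; this headroom drives the normalisation of the reciprocal below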
+        clz_op = Operation("CLZ", self.op.name + "_clz5")
+        TensorUtil.add_ifm_tensor(clz_op, sum_of_exp)
+        headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
+        headroom_plus_one.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(clz_op, headroom_plus_one)
+
+        # PASS 6 - Sub
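+        # Compute 31 - headroom_plus_one (constant first, hence reverse_op_order); reused as the final shift in PASS 13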
+        sub6_op = Operation("SubAct", self.op.name + "_sub6")
+        TensorUtil.add_ifm_tensor(sub6_op, headroom_plus_one)
+        TensorUtil.add_ifm_tensor(
+            sub6_op,
+            TensorUtil.create_const_tensor(
+                sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.uint32, quantization=no_scale_quant
+            ),
+        )
+        # TODO: Adding this attribute to reverse the operand order is not ideal;
+        #       it should be handled automatically by register_command_stream_generator
+        #       or added as an internal operator.
+        sub6_op.attrs["reverse_op_order"] = True
+        reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
+        reciprocal_right_shift.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(sub6_op, reciprocal_right_shift)
+
+        # PASS 7 - SHL
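+        # Compute 1 << reciprocal_right_shift (constant first), i.e. a fixed-point 1.0 at the current headroom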
+        shl7_op = Operation("SHL", self.op.name + "_shl7")
+        TensorUtil.add_ifm_tensor(shl7_op, reciprocal_right_shift)
+        TensorUtil.add_ifm_tensor(
+            shl7_op,
+            TensorUtil.create_const_tensor(
+                shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.uint32, quantization=no_scale_quant
+            ),
+        )
+        # TODO: See above
+        shl7_op.attrs["reverse_op_order"] = True
+        constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
+        constant_one.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(shl7_op, constant_one)
+
+        # PASS 8 - Sub
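+        # Remove the fixed-point 1.0 so only the fractional part of the normalised sum remains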
+        sub8_op = Operation("SubAct", self.op.name + "_sub8")
+        TensorUtil.add_ifm_tensor(sub8_op, sum_of_exp)
+        TensorUtil.add_ifm_tensor(sub8_op, constant_one)
+        sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
+        sum_of_exps_minus_one.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(sub8_op, sum_of_exps_minus_one)
+
+        # PASS 9 - SHL
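+        # Normalise by the headroom, leaving x in Q0.31 where the normalised sum is (1 + x)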
+        shl9_op = Operation("SHL", self.op.name + "_shl9")
+        TensorUtil.add_ifm_tensor(shl9_op, sum_of_exps_minus_one)
+        TensorUtil.add_ifm_tensor(shl9_op, headroom_plus_one)
+        shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
+        shifted_sum_minus_one.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(shl9_op, shifted_sum_minus_one)
+
+        # PASS 10 - SHR
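+        # Shift right by 15 to produce a 16-bit argument for the 1/(1+x) LUT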
+        shr10_op = Operation("SHR", self.op.name + "_shr10")
+        TensorUtil.add_ifm_tensor(shr10_op, shifted_sum_minus_one)
+        TensorUtil.add_ifm_tensor(
+            shr10_op,
+            TensorUtil.create_const_tensor(
+                shr10_op.name + "_const", [1, 1, 1, 1], DataType.int32, [15], np.uint32, quantization=no_scale_quant
+            ),
+        )
+        shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
+        shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
+        TensorUtil.set_ofm_tensor(shr10_op, shifted_sum_minus_one_16)
+
+        # PASS 11 - Sub+LUT(one over one plus x)
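+        # Centre the argument around zero and apply the 1/(1+x) LUT, giving the reciprocal of the sum of exponentials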
+        sub11_op = Operation("SubAct", self.op.name + "_sub11")
+        TensorUtil.add_ifm_tensor(sub11_op, shifted_sum_minus_one_16)
+        TensorUtil.add_ifm_tensor(
+            sub11_op,
+            TensorUtil.create_const_tensor(
+                sub11_op.name + "_const", [1, 1, 1, 1], DataType.int32, [32768], np.uint32, quantization=no_scale_quant
+            ),
+        )
+        sub11_op.set_activation_lut(
+            TensorUtil.create_const_tensor(
+                sub11_op.name + "_lut",
+                [1, 1, 1, 512],
+                DataType.int32,
+                self.ONE_OVER_ONE_PLUS_X_LUT,
+                np.uint32,
+                TensorPurpose.LUT,
+            )
+        )
+        reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
+        reciprocal_scale.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(sub11_op, reciprocal_scale)
+
+        # PASS 12 - Multiply
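+        # Multiply each exponential by the reciprocal of the sum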
+        mul_op = Operation("MulAct", self.op.name + "_mul12")
+        TensorUtil.add_ifm_tensor(mul_op, exp_ofm)
+        TensorUtil.add_ifm_tensor(mul_op, reciprocal_scale)
+        mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
+        mul_ofm.quantization = no_scale_quant
+        TensorUtil.set_ofm_tensor(mul_op, mul_ofm)
+
+        # PASS 13 - SHR
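+        # Shift right by reciprocal_right_shift to undo the normalisation and produce the int16 softmax output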
+        shr13_op = Operation("SHR", self.op.name + "_shr13")
+        TensorUtil.add_ifm_tensor(shr13_op, mul_ofm)
+        TensorUtil.add_ifm_tensor(shr13_op, reciprocal_right_shift)
+        TensorUtil.set_ofm_tensor(shr13_op, ofm)
+
+        return shr13_op
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 3ec3429..73e219b 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -20,19 +20,20 @@
 
 
 class SupportedOperators:
-    def __init__(self):
+    def __init__(self, softmax_support):
+        self.softmax_support = softmax_support
         # Categorised lists of supported operators
-        self.npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead"))
-        self.convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D"))
+        self.npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
+        self.convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D",))
         self.depthwise_convolution_ops = set(
-            ("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D")
+            ("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D,")
         )
         self.transpose_convolution_ops = set(("Conv2DBackpropInput",))
-        self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct"))
-        self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct"))
-        self.pooling_ops = self.max_pooling_ops | self.avg_pooling_ops
+        self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct",))
+        self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct",))
+        self.pooling_ops = set(("ReduceSum",)) | self.max_pooling_ops | self.avg_pooling_ops
         self.resizing_ops = set(("ResizeBilinear",))
-        self.fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct"))
+        self.fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct",))
         self.mac_main_ops = (
             # convolutions
             self.convolution_ops
@@ -47,34 +48,56 @@
             # FC layers
             | self.fc_vector_products
             # RNN/LSTM/GRU
-            | set(("BlockLSTM"))
+            | set(("BlockLSTM",))
         )
-        self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs"))
-        self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum"))
+        self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
+        self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
         self.binary_elem_wise_add_mul_sub = set(
-            ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
+            (
+                "AddAct",
+                "MulAct",
+                "SubAct",
+                "QuantizedAdd",
+                "QuantizedSub",
+                "QuantizedMul",
+                "Mul",
+                "Add",
+                "Sub",
+                "SHL",
+                "SHR",
+            )
         )
         self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
         self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
         self.activation_ops = set(
-            ("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh")
+            (
+                "QuantizedRelu",
+                "QuantizedRelu1",
+                "QuantizedRelu6",
+                "Relu",
+                "Relu6",
+                "ReluN1To1",
+                "Sigmoid",
+                "Tanh",
+                "Softmax",
+            )
         )
         self.npu_post_ops = (
             # activation functions
             self.activation_ops
             # concatenation write direction
-            | set(("ConcatSliceWrite"))
+            | set(("ConcatSliceWrite",))
             # bias add and batch norm
-            | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm"))
+            | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm",))
             # Quantization
             | set(("Quantize",))
         )
-        self.split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack"))
-        self.concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack"))
+        self.split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack",))
+        self.concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack",))
         self.memory_only_ops = (
-            set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")) | self.concat_ops | self.split_ops
+            set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",)) | self.concat_ops | self.split_ops
         )
-        self.supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid"))
+        self.supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
         self.supported_operators = (
             self.npu_pre_ops | self.mac_main_ops | self.elem_wise_main_ops | self.npu_post_ops | self.memory_only_ops
         )
@@ -103,6 +126,7 @@
         self.supported_operator_restrictions.update(
             {op: self.check_quantization_restrictions for op in self.binary_elem_wise_min_max_ops}
         )
+        self.supported_operator_restrictions.update({op: self.check_activation_ops for op in self.activation_ops})
 
     def is_operator_supported(self, op):
         if op.type not in self.supported_operators:
@@ -127,7 +151,10 @@
         for t in tensors:
             if not (t.dtype.type & BaseType.Int):
                 return False
-            if t.element_size() > 2 and op.type not in ("Requantize") | self.binary_elem_wise_add_mul_sub:
+            if (
+                t.element_size() > 2
+                and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub
+            ):
                 return False
             # check size
             if any(dim > 65536 for dim in t.shape):
@@ -212,7 +239,9 @@
         # check data type
         ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
         if ifm_tensor.dtype != ofm_tensor.dtype:
-            return False
+            if op.type != "ReduceSum":
+                return False
+            # TODO: else check ReduceSum restrictions.
 
         # check batch size
         if ifm_tensor.shape[0] != 1:
@@ -309,9 +338,33 @@
 
     def check_quantization_restrictions(self, op):
         # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
-        if (len(op.inputs) == 2
-            and not op.inputs[0].quantization == op.inputs[1].quantization == op.outputs[0].quantization):
-            print("Warning: Input/output tensors with different quantization is unsupported for the", op.type,
-                  "operator")
+        if (
+            len(op.inputs) == 2
+            and not op.inputs[0].quantization == op.inputs[1].quantization == op.outputs[0].quantization
+        ):
+            print(
+                "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
+            )
             return False
-        return True
\ No newline at end of file
+        return True
+
+    def check_activation_ops(self, op):
+        if op.type == "Softmax":
+            if not self.softmax_support:
+                return False
+
+            ifm_tensor = op.inputs[0]
+            ofm_tensor = op.outputs[0]
+
+            # check data type
+            if ifm_tensor.dtype != ofm_tensor.dtype:
+                return False
+
+            if ifm_tensor.dtype != DataType.int16:
+                return False  # TODO: Implement support for 8-bit Softmax
+
+            # check batch size
+            if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1:
+                return False
+
+        return True
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 1a071e6..c2d6b6e 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -73,13 +73,14 @@
     Weights = 1
     FeatureMap = 2
     Scratch = 3
-    Size = 4
+    LUT = 4
+    Size = 5
 
     def display_name(self):
-        return ("Unknown", "Weights", "FeatureMap", "Scratch", "Size")[self.value]
+        return ("Unknown", "Weights", "FeatureMap", "Scratch", "LUT", "Size")[self.value]
 
     def identifier_name(self):
-        return ("unknown", "weights", "feature_map", "scratch", "size")[self.value]
+        return ("unknown", "weights", "feature_map", "scratch", "lut", "size")[self.value]
 
     def all():
         return (TensorPurpose.Weights, TensorPurpose.FeatureMap)
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 1766750..b1edf34 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -253,6 +253,13 @@
         choices=[True, False],
         help="Control if NHCWB16 or NHWC should be used in between cascaded passes (default: %(default)s)",
     )
+    parser.add_argument(
+        "--softmax-support",
+        type=ast.literal_eval,
+        default=False,
+        choices=[True, False],
+        help="Control if Softmax should be transformed into a set of npu operations (default: %(default)s)",
+    )
 
     args = parser.parse_args(args=args)
 
@@ -283,6 +290,7 @@
         block_config_limit=args.block_config_limit,
         global_memory_clock_scale=args.global_memory_clock_scale,
         max_blockdep=args.max_block_dependency,
+        softmax_support=args.softmax_support,
     )
 
     compiler_options = compiler_driver.CompilerOptions(