MLBEDSW-6435: Implement support for ArgMax along depth dimension

- Add support for ArgMax along the depth dimension, with a depth limit of 127.
- Only 8-bit input and 32-bit output are supported.

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I5f6f0503135bebabbb1ca637f9729587b7c60740
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 860d1fe..6f6167d 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.7.1.dev2+g19f8967.d20230301`
+Vela version: `3.7.1.dev8+ga182a70.d20230322`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -20,6 +20,7 @@
 | --- | --- |
 | ABS | [Generic](#tflite-generic-constraints), [Specific](#tflite-abs-constraints) |
 | ADD | [Generic](#tflite-generic-constraints), [Specific](#tflite-add-constraints) |
+| ARG_MAX | [Generic](#tflite-generic-constraints), [Specific](#tflite-arg_max-constraints) |
 | AVERAGE_POOL_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-average_pool_2d-constraints) |
 | CONCATENATION | [Generic](#tflite-generic-constraints), [Specific](#tflite-concatenation-constraints) |
 | CONV_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-conv_2d-constraints) |
@@ -64,14 +65,14 @@
 - Input(s) and Output tensors must not be dynamic - [QUANTIZE]
 - Input(s) and Output tensors must have a defined shape
 - Output tensors cannot be scalar - [QUANTIZE]
-- Scalar Input tensors are only valid for op type: ADD, EXPAND_DIMS, MAXIMUM, MEAN, MINIMUM, MUL, QUANTIZE, SPLIT, SPLIT_V, SUB
+- Scalar Input tensors are only valid for op type: ADD, ARG_MAX, EXPAND_DIMS, MAXIMUM, MEAN, MINIMUM, MUL, QUANTIZE, SPLIT, SPLIT_V, SUB
 - Input(s) and Output tensors must not be greater than 4D
-- Input(s), Output and Weight tensors must have quantization parameters - [SHAPE]
+- Input(s), Output and Weight tensors must have quantization parameters - [ARG_MAX, SHAPE]
 - Input(s), Output and Weight tensors with quantization scales must be finite
 - Input and Output tensors must have quantization scales that fit within float32 precision
 - Constant tensors should not have NoneType-values
 - Tensors must be of type: int16, int32, int8, uint8
-- Tensors which are int32 are only valid when op type is: ADD, MUL, SHAPE, SUB
+- Tensors which are int32 are only valid when op type is: ADD, ARG_MAX, MUL, SHAPE, SUB
 - Tensor dimensions must be in the range [1, 65535]
 - Per-axis quantization is only supported for the following op types: CONV_2D, DEPTHWISE_CONV_2D, TRANSPOSE_CONV
 - IFM Tensor batch size must be 1 - [FULLY_CONNECTED, RESHAPE, SHAPE, SLICE, SOFTMAX, SPLIT, SPLIT_V, SQUEEZE, STRIDED_SLICE, UNPACK]
@@ -95,6 +96,15 @@
 - For IFM that are unsigned, OFM must either be the same type or int32
 - Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2
 
+### TFLite ARG_MAX Constraints
+
+This is a list of constraints that the ARG_MAX operator must satisfy in order to be scheduled on the NPU.
+
+- IFM must be int8 or uint8
+- Number of input dimensions must be 4
+- Operation must be performed along the depth axis
+- IFM depth must be no greater than 127
+
 ### TFLite AVERAGE_POOL_2D Constraints
 
 This is a list of constraints that the AVERAGE_POOL_2D operator must satisfy in order to be scheduled on the NPU.
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6be9dc2..6771710 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -134,7 +134,7 @@
     Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     AddN = OperatorInfo()
     Any = OperatorInfo()
-    ArgMax = OperatorInfo()
+    ArgMax = OperatorInfo(indices=NNG_IFM_INDICES)
     ArgMin = OperatorInfo()
     AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
     Atan2 = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index a1cbb3e..44f5d6a 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -50,6 +50,7 @@
 from .operation import Padding
 from .operation_util import create_add_nop
 from .operation_util import create_avgpool_nop
+from .operation_util import create_depthwise_maxpool
 from .operation_util import get_pad_values_from_input
 from .scaling import quantise_scale
 from .shape4d import Shape4D
@@ -460,6 +461,161 @@
     return op
 
 
+def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
+    """
+    Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.
+
+    Example:
+    arr = [4,   [00000100,
+           6, =  00000110,  # <-- This is the largest value, so we're expecting argmax(arr) = 1
+           5]    00000101]
+
+    Use 16-bit precision and shift all values 7 bits to the left:
+    shifted_arr = [0000001000000000,
+                   0000001100000000,
+                   0000001010000000]
+
+    Add "c - index of channel" to each channel:
+    shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
+                                    0000001100000001, (+1)
+                                    0000001010000000] (+0)
+
+    The index is reversed since ArgMax selects the lowest index when the maximum value occurs at more than one
+    index. The added index acts as a tie-breaker between channels with equal values, and since we want the smallest
+    channel index to win, we reverse the index before the maxpool and then subtract the result from (c - 1) after
+    the maxpool to recover the correct index.
+
+    Find the maximum value in the array:
+    val = max(shifted_arr_plus_reverse_idx) = 0000001100000001
+
+    Use a LUT activation to extract the 7 lowest bits of val (cutting off the 9 most significant bits) and
+    subtract the resulting reverse index from (c - 1), recovering the real index:
+    idx = LUT(val) = (c - 1) - (val & 0b1111111) = 2 - 1 = 1
+
+    Only the 7 lowest bits are available to hold the reverse index, which is why the IFM depth is limited to 127.
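+
+    As a rough NumPy sketch (illustration only; argmax_via_maxpool is a hypothetical helper, not part of this
+    patch), the whole trick corresponds to:
+
+        import numpy as np
+
+        def argmax_via_maxpool(arr):
+            c = len(arr)
+            # Shift values into the high bits and add the reverse channel index as a tie-breaker
+            keyed = (np.array(arr, dtype=np.int32) << 7) + (c - 1 - np.arange(c))
+            # The maxpool finds the largest keyed value; the LUT step recovers the index
+            return (c - 1) - (keyed.max() & 0x7F)
+
+        assert argmax_via_maxpool([4, 6, 5]) == 1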
+    """
+
+    if op.type == Op.ArgMax:
+        ifm, ofm = op.inputs[0], op.outputs[0]
+        identity_quant = QuantizationParameters()
+        identity_quant.zero_point = 0
+        identity_quant.scale_f32 = 1.0
+        if ofm.quantization is None:
+            ofm.quantization = identity_quant
+        # Add last dimension to ofm shape
+        ofm.shape += [1]
+        ofm.ops = []
+
+        # Create a 1x1 depthwise convolution with a weight of 2**7 for each channel to convert the precision to
+        # 16 bit and shift all values 7 bits to the left
+        # Set necessary depthwise attributes
+        dw_op_attrs = {
+            "padding": Padding.VALID,
+            "stride_h": 1,
+            "stride_w": 1,
+            "strides": (1, 1, 1, 1),
+            "depth_multiplier": 1,
+            "channel_multiplier": 1,
+            "dilation_h_factor": 1,
+            "dilation_w_factor": 1,
+            "dilation": (1, 1, 1, 1),
+            "explicit_padding": None,
+        }
+        op.name = "depthwise_conv_SHL_7"
+        op.type = Op.DepthwiseConv2DBias
+        op.attrs.update(dw_op_attrs)
+        n, h, w, c = ifm.shape
+        shape = [1, 1, 1, c]
+        kernel = np.dstack([2**7] * c)
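+        # np.dstack of c scalar weights yields shape (1, 1, c), which is reshaped to [1, 1, 1, c] below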
+        op.inputs = []
+        op.add_input_tensor(ifm)
+        op.add_input_tensor(
+            create_const_tensor(
+                "weights",
+                shape,
+                DataType.uint8,
+                np.array(kernel).reshape(shape),
+                quantization=identity_quant,
+            ),
+        )
+        # Let the bias for each channel be the "reverse" index of that channel, i.e. (c - 1) - channel_idx
+        reverse_idxs = list(reversed(range(c)))
+        bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
+        op.add_input_tensor(bias_tensor)
+
+        intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
+        intermediate_tens.quantization = ifm.quantization
+        op.set_output_tensor(intermediate_tens)
+        op.set_ifm_ofm_shapes()
+        orig_ifm_shape = op.ifm_shapes[0]
+        DebugDatabase.add_optimised(op, op)
+
+        # To extract the 7 least significant bits and map the reverse index back to the real index using a LUT
+        # activation, we set the base value to c-1 and the slope to -128. The 16-bit LUT uses a table of 32-bit
+        # values where the top 16 bits represent the slope and the bottom 16 bits the base, which are used to
+        # interpolate the activation value.
+        slope = (-128 & 0xFFFF) << 16  # Top 16 bits of 32 bit LUT table value
+        base = c - 1  # Bottom 16 bits of the LUT table value
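+        # e.g. for c = 3: slope = 0xFF80 << 16 and base = 2, so every table entry is 0xFF800002 and the LUT
+        # effectively computes (c - 1) - (input & 0x7F)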
+        lut_tensor = create_const_tensor(
+            "maxpool_LUT_extract_7_LSB",
+            [1, 1, 1, 512],
+            DataType.uint32,
+            [slope + base] * 512,
+            TensorPurpose.LUT,
+        )
+
+        # Split large feature maps into smaller chunks since the depthwise maxpool height dimension can overflow
+        # when the IFM is flattened to (H*W)xCx1
+        max_height = 2**16 // orig_ifm_shape.width
+        num_full_height_ops = orig_ifm_shape.height // max_height
+        last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
+        op_heights = [max_height] * num_full_height_ops
+        if last_op_height > 0:
+            op_heights.append(last_op_height)
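+        # e.g. height = 700, width = 100 gives max_height = 65536 // 100 = 655 and op_heights = [655, 45]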
+
+        # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
+        # maximum allowed height, but that's handled by reading and writing the data in chunks
+        maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
+        maxpool_ofm.quantization = identity_quant
+
+        for op_idx, op_height in enumerate(op_heights):
+            maxpool_op = create_depthwise_maxpool(
+                f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
+            )
+            maxpool_op.outputs = [maxpool_ofm]
+            maxpool_ofm.ops.append(maxpool_op)
+            maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
+            maxpool_op.set_activation_lut(lut_tensor)
+
+            # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
+            maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
+            maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+            maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
+            maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
+            DebugDatabase.add_optimised(op, maxpool_op)
+
+        # Convert output to OFM dtype and reshape back to original OFM shape with 1x1 DWConv
+        dw_conv = Operation(Op.DepthwiseConv2DBias, f"depthwise_conv_convert_to_32bit_{op_idx}")
+        dw_conv.attrs.update(dw_op_attrs)
+        dw_conv.inputs = [maxpool_op.ofm]
+        dw_conv.add_input_tensor(
+            create_const_tensor(
+                "weights",
+                [1, 1, 1, 1],
+                DataType.uint8,
+                np.array([1]).reshape([1, 1, 1, 1]),
+                quantization=identity_quant,
+            ),
+        )
+        dw_conv.add_input_tensor(create_const_tensor(dw_conv.name + "_bias", [1], DataType.int64, [0]))
+        ofm.ops.append(dw_conv)
+        dw_conv.outputs = [ofm]
+        dw_conv.ifm_shapes.append(Shape4D([1, orig_ifm_shape.height, orig_ifm_shape.width, 1]))
+        dw_conv.ofm_shapes.append(Shape4D(ofm.shape))
+        DebugDatabase.add_optimised(op, dw_conv)
+
+    return op
+
+
 def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
     def _compute_interpolation_values(index, input_size, output_size):
         scale = input_size / output_size
@@ -1976,6 +2132,7 @@
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
         reorder_depthwise_weights,
+        convert_argmax_to_depthwise_conv_and_max_pool,
         fixup_resize,
         fixup_bias_tensors,
         fixup_asymmetric_weights,
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 8ec0173..98fe287 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -749,7 +749,7 @@
     BuiltinOperator.ARG_MAX: (
         Op.ArgMax,
         OptionsSerializer("ArgMaxOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
-        TFLITE_NO_INDICES,
+        TFLITE_IFM_INDICES,
     ),
     BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
     BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions"), TFLITE_NO_INDICES),
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 9f53a1e..495d71a 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -77,7 +77,9 @@
     )
     binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
-    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize))
+    shapeless_input_ops = binary_elem_wise_main_ops | set(
+        (Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize, Op.ArgMax)
+    )
     reshape_ops = set(
         (
             Op.Reshape,
@@ -187,6 +189,9 @@
         self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
         self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)
 
+        # ArgMax specific checks:
+        self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
+
     def is_operator_semantic_valid(self, op):
         ext_type = optype_to_builtintype(op.type)
 
@@ -226,6 +231,9 @@
                 TFLiteSemantic.constraint_tens_no_dynamic,
                 TFLiteSemantic.constraint_tens_output_scalar,
             ],
+            Op.ArgMax: [
+                TFLiteSemantic.constraint_tens_quant_none_check,
+            ],
         }
         return generic_constraints_exclude_list
 
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 26ccfeb..fd9a9c2 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -81,6 +81,8 @@
         | fc_vector_products
         # Mean (converts to depthwise conv)
         | set((Op.Mean,))
+        # ArgMax (converts to depthwise conv and maxpool)
+        | set((Op.ArgMax,))
     )
     unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
     binary_elem_wise_min_max_ops = set(
@@ -106,15 +108,7 @@
     elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
     pad_ops = set((Op.Pad,))
     supported_int32_tensor_ops = (
-        set(
-            (
-                Op.ReduceSum,
-                Op.CLZ,
-                Op.Shape,
-            )
-        )
-        | binary_elem_wise_add_mul_sub
-        | binary_elem_wise_shift_ops
+        set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     )
 
     relu_ops = set(
@@ -321,6 +315,11 @@
         # Reshape specific checks:
         self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
 
+        # ArgMax specific checks:
+        self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions)
+        self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
+        self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
+
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
         if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -873,3 +872,25 @@
         extra = ", ".join(extra)
 
         return valid, f"Op has non-const input(s): {extra}"
+
+    @staticmethod
+    def constraint_argmax_axis(op):
+        "Operation must be performed along the depth axis"
+        inp_dims = len(op.inputs[0].shape)
+        axis = op.inputs[1].values
+        return (
+            axis in (3, -1),
+            f"Axis is {axis} and number of input dimensions is {inp_dims}",
+        )
+
+    @staticmethod
+    def constraint_argmax_input_dimensions(op):
+        "Number of input dimensions must be 4"
+        inp_dims = len(op.inputs[0].shape)
+        return inp_dims == 4, f"Number of input dimensions is {inp_dims}"
+
+    @staticmethod
+    def constraint_argmax_depth(op):
+        "IFM depth must be no greater than 127"
+        ifm_depth = op.inputs[0].shape[-1]
+        return ifm_depth <= 127, f"IFM depth is {ifm_depth}"