MLBEDSW-7437: Add 64-bit output support for ArgMax

- Added 64-bit output support for ArgMax

- Updated constraints for ArgMax and regenerated SUPPORTED_OPS.md

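A minimal host-side numpy sketch of the byte-level widening described by the
diagram added in tflite_graph_optimiser.py (illustration only, not part of the
change; the index value 5 and the variable names are arbitrary examples, and
.view() stands in for the raw-byte Memcpy done on the device):

    import numpy as np

    idx_i16 = np.array([[5]], dtype=np.int16)   # A: the argmax index as int16
    b_i32 = idx_i16.astype(np.int32)            # B: first cast, int16 -> int32
    c_i16x2 = b_i32.view(np.int16)              # C: memcpy, reinterpret as two int16 halfwords
    d_i32x2 = c_i16x2.astype(np.int32)          # D: second cast, each halfword -> int32
    ofm_i64 = d_i32x2.view(np.int64)            # OFM: memcpy, two int32 words read as one int64
    assert ofm_i64.item() == 5
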
Change-Id: I4ef7d2e6fccab0088b87757f6afe40a006c77bbd
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index ba5b791..08c63e7 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.7.1.dev10+g521c494`
+Vela version: `3.7.1.dev15+g2b5f66e`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -71,7 +71,7 @@
 - Input(s), Output and Weight tensors with quantization scales must be finite
 - Input and Output tensors must have quantization scales that fit within float32 precision
 - Constant tensors should not have NoneType-values
-- Tensors must be of type: int16, int32, int8, uint8
+- Tensors must be of type: int16, int32, int8, uint8 - [ARG_MAX]
 - Tensors which are int32 are only valid when op type is: ADD, ARG_MAX, MUL, SHAPE, SUB
 - Tensor dimensions must be in the range [1, 65535]
 - Per-axis quantization is only supported for the following op types: CONV_2D, DEPTHWISE_CONV_2D, TRANSPOSE_CONV
@@ -101,6 +101,7 @@
 This is a list of constraints that the ARG_MAX operator must satisfy in order to be scheduled on the NPU.
 
 - IFM must be int8 or uint8
+- OFM must be int32 or int64
 - Operation must be performed along the depth axis
 - IFM depth must be no greater than 127
 
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 21f9dbe..74836eb 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -19,6 +19,8 @@
 from typing import Optional
 from typing import Tuple
 
+import numpy as np
+
 from .data_type import DataType
 from .high_level_command_to_npu_op import ifm_ifm2_correct_order
 from .operation import ActivationFunction
@@ -26,6 +28,7 @@
 from .operation import Operation
 from .operation import Padding
 from .shape4d import Shape4D
+from .tensor import create_const_tensor
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
@@ -51,9 +54,16 @@
     return op
 
 
-def create_memcpy(name: str) -> Operation:
+def create_memcpy(
+    name: str,
+    ifm: Tensor,
+    ofm: Tensor,
+) -> Operation:
     op = Operation(Op.Memcpy, name)
     op.run_on_npu = True
+    op.add_input_tensor(ifm)
+    op.set_output_tensor(ofm)
+    op.set_ifm_ofm_shapes()
     return op
 
 
@@ -63,6 +73,50 @@
     return op
 
 
+def create_cast_op(
+    name: str,
+    ifm: Tensor,
+    ofm: Tensor,
+) -> Operation:
+    op = Operation(Op.DepthwiseConv2DBias, name)
+    op_attrs = {
+        "padding": Padding.VALID,
+        "stride_h": 1,
+        "stride_w": 1,
+        "strides": (1, 1, 1, 1),
+        "depth_multiplier": 1,
+        "channel_multiplier": 1,
+        "dilation_h_factor": 1,
+        "dilation_w_factor": 1,
+        "dilation": (1, 1, 1, 1),
+        "explicit_padding": None,
+    }
+    op.attrs.update(op_attrs)
+    op.add_input_tensor(ifm)
+
+    c = ifm.shape[-1]
+
+    shape = [1, 1, 1, c]
+    kernel = np.dstack([1] * c)
+    identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
+    op.add_input_tensor(
+        create_const_tensor(
+            op.name + "_weights",
+            shape,
+            DataType.uint8,
+            np.array(kernel).reshape(shape),
+            quantization=identity_quant,
+        ),
+    )
+    bias_values = [0] * c
+    dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
+    op.add_input_tensor(create_const_tensor(op.name + "_bias", [c], dtype, bias_values))
+    op.set_output_tensor(ofm)
+    op.set_ifm_ofm_shapes()
+
+    return op
+
+
 def create_depthwise_maxpool(
     name: str,
     ifm: Tensor,
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 5b0e2fb..077f4af 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -51,7 +51,9 @@
 from .operation import Padding
 from .operation_util import create_add_nop
 from .operation_util import create_avgpool_nop
+from .operation_util import create_cast_op
 from .operation_util import create_depthwise_maxpool
+from .operation_util import create_memcpy
 from .operation_util import get_pad_values_from_input
 from .scaling import quantise_scale
 from .shape4d import Shape4D
@@ -520,7 +522,8 @@
             "dilation": (1, 1, 1, 1),
             "explicit_padding": None,
         }
-        op.name = "depthwise_conv_SHL_7"
+        orig_name = op.name
+        op.name = f"{orig_name}_depthwise_conv_SHL_7"
         op.type = Op.DepthwiseConv2DBias
         op.attrs.update(dw_op_attrs)
         n, h, w, c = full_shape(4, ifm.shape, 1)
@@ -592,25 +595,43 @@
             maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
             DebugDatabase.add_optimised(op, maxpool_op)
 
-        # Convert output to OFM dtype and reshape back to original OFM shape with 1x1 DWConv
-        dw_conv = Operation(Op.DepthwiseConv2DBias, f"depthwise_conv_convert_to_32bit_{op_idx}")
-        dw_conv.attrs.update(dw_op_attrs)
-        dw_conv.inputs = [maxpool_op.ofm]
-        dw_conv.add_input_tensor(
-            create_const_tensor(
-                "weights",
-                [1, 1, 1, 1],
-                DataType.uint8,
-                np.array([1]).reshape([1, 1, 1, 1]),
-                quantization=identity_quant,
-            ),
-        )
-        dw_conv.add_input_tensor(create_const_tensor(dw_conv.name + "_bias", [1], DataType.int64, [0]))
-        ofm.ops.append(dw_conv)
-        dw_conv.outputs = [ofm]
-        dw_conv.ifm_shapes.append(Shape4D([1, orig_ifm_shape.height, orig_ifm_shape.width, 1]))
-        dw_conv.ofm_shapes.append(Shape4D(ofm.shape))
-        DebugDatabase.add_optimised(op, dw_conv)
+        # Set final shape
+        maxpool_ofm.set_all_shapes([1, h, w, 1])
+
+        # Convert 16bit to 32bit or 64bit
+        if ofm.dtype == DataType.int64:
+            # If OFM dtype is int64, the result is converted by two cast ops (16bit to 32bit)
+            #
+            #   A     -> B         -> C          -> D (OFM)
+            #   |0001|   |00010000|   |0001|0000|   |00010000|00000000|
+            #    i16      i32          i16  i16      i32      i32
+            #                                       <-------i64------->
+            #
+            #   Memcpy is used to copy the content from B to C and from D to OFM
+            #   Memcpy will be turned into a nop or a DMA transfer if the memory regions differ.
+            intermediate_32bit = Tensor([1, h, w, 1], DataType.int32, f"{orig_name}_32bit")
+        else:
+            intermediate_32bit = ofm
+
+        op_cast = create_cast_op(f"{orig_name}_cast_to_32bit_1", maxpool_ofm, intermediate_32bit)
+        DebugDatabase.add_optimised(op, op_cast)
+
+        if ofm.dtype == DataType.int64:
+            # Create an int16 tensor with doubled depth to cover the intermediate_32bit result from the first cast
+            intermediate_16bit_2x_size = Tensor([1, h, w, 2], DataType.int16, f"{orig_name}_16bit_2x_size")
+            memcpy_op = create_memcpy(f"{orig_name}_memcpy_1", intermediate_32bit, intermediate_16bit_2x_size)
+            DebugDatabase.add_optimised(op, memcpy_op)
+
+            # Create an int32 tensor with doubled depth to be able to store an "int64" result
+            intermediate_32bit_2x_size = Tensor([1, h, w, 2], DataType.int32, f"{orig_name}_32bit_2x_size")
+
+            op_cast = create_cast_op(
+                f"{orig_name}_cast_to_32bit_2", intermediate_16bit_2x_size, intermediate_32bit_2x_size
+            )
+            DebugDatabase.add_optimised(op, op_cast)
+
+            memcpy_op = create_memcpy(f"{orig_name}_memcpy_2", intermediate_32bit_2x_size, ofm)
+            DebugDatabase.add_optimised(op, memcpy_op)
 
     return op
 
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 495d71a..5661f36 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -191,6 +191,7 @@
 
         # ArgMax specific checks:
         self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
+        self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output)
 
     def is_operator_semantic_valid(self, op):
         ext_type = optype_to_builtintype(op.type)
@@ -634,6 +635,13 @@
         return valid, f"Op has ifm_dtype={ifm_dtype}"
 
     @staticmethod
+    def constraint_argmax_output(op):
+        "OFM must be int32 or int64"
+        ofm_dtype = op.ofm.dtype
+        valid = ofm_dtype in (DataType.int32, DataType.int64)
+        return valid, f"Op has ofm_dtype={ofm_dtype}"
+
+    @staticmethod
     def constraint_matching_either_shapes(op):
         "At least one Input's shape must match the OFM's shape"
         ifm_shape = op.ifm.shape
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 66b9e94..25f19b7 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -206,6 +206,7 @@
 
         # Setup generic constraint exceptions
         self.generic_constraints_exceptions = defaultdict(list)
+        self.generic_constraints_exceptions[Op.ArgMax].append(TFLiteSupportedOperators.constraint_tens_dtype)
         self.generic_constraints_exceptions[Op.FullyConnected].append(TFLiteSupportedOperators.constraint_batch_size)
         self.generic_constraints_exceptions[Op.Softmax].append(TFLiteSupportedOperators.constraint_batch_size)
         self.generic_constraints_exceptions[Op.Reshape].append(TFLiteSupportedOperators.constraint_batch_size)