MLBEDSW-7756: MLCE: Grouped convolutions runtime problem

 - Added a graph optimiser function to convert convolution groups into
a split followed by separate convolutions and then a concat (see the
sketch below)
 - Added semantic checks for convolution groups
 - Added unit tests for the convolution group semantic checks
 - Fixed a minor typing issue with test_constraint_stride_range
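
For reference, below is a minimal standalone numpy sketch (illustrative
only, not code from this patch) of the decomposition that
convert_conv_groups performs. It assumes a 1x1 kernel for brevity and
borrows the shapes from the new unit tests: IFM depth 15 with filter
kernel depth 3 gives 5 convolution groups, and 20 filter kernels give 4
filters per group. NHWC and HWIO layouts are assumed, matching the
comments in the new semantic checks; conv2d_1x1 is a toy helper, not a
Vela function.

    import numpy as np

    ifm = np.random.rand(1, 8, 8, 15)      # NHWC, IFM depth 15
    weights = np.random.rand(1, 1, 3, 20)  # HWIO, kernel depth 3, 20 filters

    groups = ifm.shape[-1] // weights.shape[-2]  # 15 // 3 = 5 conv groups
    filters_cg = weights.shape[-1] // groups     # 20 // 5 = 4 filters per group

    def conv2d_1x1(x, w):
        # plain (non-grouped) 1x1 convolution: a matmul over the channel axis
        return np.einsum("nhwc,co->nhwo", x, w[0, 0])

    # grouped convolution expressed as split -> per-group conv -> concat,
    # which is the structure the graph optimiser now builds
    ifm_parts = np.split(ifm, groups, axis=-1)
    ofm_parts = [
        conv2d_1x1(part, weights[..., g * filters_cg : (g + 1) * filters_cg])
        for g, part in enumerate(ifm_parts)
    ]
    ofm = np.concatenate(ofm_parts, axis=-1)
    assert ofm.shape == (1, 8, 8, 20)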

Change-Id: I78ade408aa23469a79c9f517c4751da8619b77a9
Signed-off-by: Tim Hall <tim.hall@arm.com>
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 7b46b8b..8a992b5 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -19,7 +19,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.8.1.dev14+gefc7d21e.d20230707`
+Vela version: `3.8.1.dev17+g1d5e859.d20230711`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -155,6 +155,8 @@
 This is a list of constraints that the CONV_2D operator must satisfy in order to be scheduled on the NPU.
 
 - Stride values for both width and height must be integer types
+- IFM depth must be a whole multiple of the filter kernel depth
+- Number of filter kernels must be equally divisible by the number of convolution groups
 - Dilation factor values for both width and height must be integer types
 - Stride width must be greater than or equal to 1.  
         For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3.  
diff --git a/ethosu/vela/driver_actions.py b/ethosu/vela/driver_actions.py
index 2e4412c..a711146 100644
--- a/ethosu/vela/driver_actions.py
+++ b/ethosu/vela/driver_actions.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -79,7 +79,7 @@
     if arch.is_ethos_u65_system:
         n.set_product(1)
     else:
-        n.set_product(0)  # U55
+        n.set_product(0)  # Ethos-U55
     n.set_shram_size(shram_size)
     n.set_cmd_stream_version(0)  # may be incremented in the future
     n.set_macs_per_cc(log2_macs_cc)
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 998d94f..31839c7 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -852,6 +852,11 @@
             return self.forced_input_quantization
         return self.ifm.quantization
 
+    def add_output_tensor(self, tens):
+        self.outputs.append(tens)
+        if self not in tens.ops:
+            tens.ops.append(self)
+
     def set_output_tensor(self, tens):
         tens.ops = [self]
         self.outputs = [tens]
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 7a82d2c..e7fd307 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -121,14 +121,32 @@
 
 def test_constraint_stride_type():
     # Stride width and height must be integer types
-    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
     op.attrs = {"stride_w": 1.5, "stride_h": "1"}
     assert not semantic_checker.is_operator_semantic_valid(op)
 
 
+def test_constraint_conv_groups_ifm_depth():
+    # Test IFM depth is a whole multiple of the filter kernel depth
+    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 3, 5])
+    assert semantic_checker.is_operator_semantic_valid(op)
+
+    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 5], weights_shape=[1, 1, 4, 5])
+    assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_constraint_conv_groups_num_filters():
+    # Test number of filter kernels is equally divisible by the number of convolution groups
+    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 20], weights_shape=[1, 1, 3, 20])
+    assert semantic_checker.is_operator_semantic_valid(op)
+
+    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 15], [1, 8, 8, 21], weights_shape=[1, 1, 3, 21])
+    assert not semantic_checker.is_operator_semantic_valid(op)
+
+
 def test_constraint_dilation_type():
     # Dilation width and height must be integer types
-    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
+    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
     op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
     assert not semantic_checker.is_operator_semantic_valid(op)
 
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index f2ad858..f54211f 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -16,6 +16,8 @@
 #
 # Description:
 # Unit tests for tflite support_operators
+from typing import List
+
 import numpy as np
 import pytest
 
@@ -121,7 +123,7 @@
         [[1, 8, 40, 8], 8, 1, True],
     ],
 )
-def test_constraint_stride_range(ifm_shape: list[int], stride_w: int, stride_h: int, supported: bool):
+def test_constraint_stride_range(ifm_shape: List[int], stride_w: int, stride_h: int, supported: bool):
     # Stride width and height must lie within a certain range
     op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, ifm_shape, [1, 8, 8, 8], [1, 1, 1, 1])
     op.attrs = {"stride_w": stride_w, "stride_h": stride_h}
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 31d3ae1..c7fe6cd 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2423,6 +2423,130 @@
     return op
 
 
+def convert_conv_groups(op: Operation, arch, nng):
+    """
+    Convert convolution groups to a split followed by separate convolutions and then a concat.
+    This needs to run before the concat and split handling functions."""
+    if not op.type.is_conv2d_op():
+        return op
+
+    num_conv_groups = op.attrs.get("num_conv_groups", 0)
+    if num_conv_groups > 1:
+        # convolution groups params
+        ifm_depth_cg = op.ifm.shape[-1] // num_conv_groups
+        num_filters_cg = op.weights.shape[-1] // num_conv_groups
+
+        # create split
+        split_op = Operation(Op.Split, f"{op.name}_split")
+        split_op.attrs.update(
+            {
+                "num_splits": num_conv_groups,
+            }
+        )
+        # first input is the split axis
+        split_op.add_input_tensor(
+            # split along the depth axis
+            create_const_tensor(f"{split_op.name}_axis", [0], DataType.int32, [-1])
+        )
+        # second input is the ifm
+        split_op.add_input_tensor(op.ifm)
+        # calculate shape of each ofm part
+        split_op_ofm_shape = op.ifm.shape[:-1] + [ifm_depth_cg]
+
+        # create concat. do this prior to each conv group so that the for-loop can reference the concat as it iterates
+        concat_op = Operation(Op.ConcatTFLite, f"{op.name}_concat")
+        concat_op.attrs.update(
+            {
+                "axis": -1,
+                "fused_activation_function": None,
+            }
+        )
+        # calculate shape of each ifm part
+        concat_op_ifm_shape = op.ofm.shape[:-1] + [num_filters_cg]
+        # output is the concatenated tensor
+        concat_op.set_output_tensor(op.ofm)  # will disconnect ofm from op
+
+        # for each conv group
+        for i in range(num_conv_groups):
+            # cg params
+            cg_oc_start = i * num_filters_cg
+            cg_oc_end = (i + 1) * num_filters_cg
+
+            # split has multiple outputs
+            split_op_ofm_part = Tensor(split_op_ofm_shape, op.ifm.dtype, f"{split_op.name}_out{i}")
+            split_op_ofm_part.quantization = op.ifm.quantization.clone()
+            split_op.add_output_tensor(split_op_ofm_part)
+
+            # concat has multiple inputs
+            concat_op_ifm_part = Tensor(concat_op_ifm_shape, op.ifm.dtype, f"{concat_op.name}_in{i}")
+            concat_op_ifm_part.quantization = op.ofm.quantization.clone()
+            concat_op.add_input_tensor(concat_op_ifm_part)
+
+            # create convolution group operator
+            conv_group_op = Operation(op.type, f"{op.name}_cg{i}")
+            conv_group_op.attrs = op.attrs.copy()
+            conv_group_op.attrs["num_conv_groups"] = 1
+            # first input is the ifm
+            conv_group_op.add_input_tensor(split_op_ofm_part)
+            # second input is weights. the number of filters (i.e. the output channels) needs to be split equally
+            # across all of the convolution groups
+            conv_group_op_weights_shape = op.weights.shape[:-1] + [num_filters_cg]
+            conv_group_op_weights_quant = op.weights.quantization.clone()
+            conv_group_op_weights_quant.scale_f32 = op.weights.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
+            conv_group_op_weights_quant.zero_point = op.weights.quantization.zero_point[..., cg_oc_start:cg_oc_end]
+            conv_group_op.add_input_tensor(
+                create_const_tensor(
+                    f"{op.weights.name}_cg{i}",
+                    conv_group_op_weights_shape,
+                    op.weights.dtype,
+                    op.weights.values[..., cg_oc_start:cg_oc_end],
+                    op.weights.purpose,
+                    conv_group_op_weights_quant,
+                )
+            )
+            # third input is bias. like the weights, the bias needs to be split equally across all of the convolution
+            # groups
+            if op.bias is None:
+                conv_group_op.add_input_tensor(None)
+            else:
+                conv_group_op_bias_shape = op.bias.shape[:-1] + [num_filters_cg]
+                conv_group_op_bias_quant = op.bias.quantization.clone()
+                conv_group_op_bias_quant.scale_f32 = op.bias.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
+                conv_group_op_bias_quant.zero_point = op.bias.quantization.zero_point[..., cg_oc_start:cg_oc_end]
+                conv_group_op.add_input_tensor(
+                    create_const_tensor(
+                        f"{op.bias.name}_cg{i}",
+                        conv_group_op_bias_shape,
+                        op.bias.dtype,
+                        op.bias.values[..., cg_oc_start:cg_oc_end],
+                        op.bias.purpose,
+                        conv_group_op_bias_quant,
+                    )
+                )
+            # output goes to the concat
+            conv_group_op.set_output_tensor(concat_op_ifm_part)
+            # update the cg op shapes and debug db
+            conv_group_op.set_ifm_ofm_shapes()
+            DebugDatabase.add_optimised(op, conv_group_op)
+
+        # update the split/concat op shapes/debug db
+        split_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, split_op)
+        concat_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, concat_op)
+
+        # disconnect the original convolution operator.
+        # the ofm has already been disconnected by concat_op.set_output_tensor()
+        op.ifm.consumer_list.remove(op)
+        op.inputs = []
+        op.outputs = []
+
+        # return last op so that other graph optimiser functions can process the new operators
+        op = concat_op
+
+    return op
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
     return op
@@ -2447,7 +2571,7 @@
         )
 
     # Pre-processing step
-    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape]
+    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape, convert_conv_groups]
 
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 3ac78b2..6264891 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -110,6 +110,10 @@
         # Conv-like checks:
         for op_type in TFLiteSemantic.convolution_like_ops:
             self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
+            if op_type in TFLiteSemantic.convolution_ops:
+                # Only Conv has groups
+                self.specific_constraints[op_type].append(TFLiteSemantic.constraint_conv_groups_ifm_depth)
+                self.specific_constraints[op_type].append(TFLiteSemantic.constraint_conv_groups_num_filters)
             if op_type not in TFLiteSemantic.transpose_convolution_ops:
                 # Transpose Conv does not contain dilation
                 self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type)
@@ -373,6 +377,36 @@
         return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"
 
     @staticmethod
+    def constraint_conv_groups_ifm_depth(op):
+        """IFM depth must be a whole multiple of the filter kernel depth"""
+        ifm_depth = op.ifm.shape[-1]  # nhwc
+        kernel_ic = op.weights.shape[-2]  # hwio
+        num_conv_groups = ifm_depth // kernel_ic
+
+        if ifm_depth % kernel_ic == 0:
+            op.attrs["num_conv_groups"] = num_conv_groups
+            valid = True
+        else:
+            valid = False
+
+        return valid, f"IFM depth = {ifm_depth} and filter kernel depth = {kernel_ic}"
+
+    @staticmethod
+    def constraint_conv_groups_num_filters(op):
+        """Number of filter kernels must be equally divisible by the number of convolution groups"""
+        ifm_depth = op.ifm.shape[-1]  # nhwc
+        kernel_ic = op.weights.shape[-2]  # hwio
+        kernel_oc = op.weights.shape[-1]  # hwio
+        num_conv_groups = ifm_depth // kernel_ic
+
+        if kernel_oc % num_conv_groups == 0:
+            valid = True
+        else:
+            valid = False
+
+        return valid, f"Filter kernels = {kernel_oc} and convolution groups = {num_conv_groups}"
+
+    @staticmethod
     def constraint_dilation_type(op):
         "Dilation factor values for both width and height must be integer types"
         w, h = op.get_kernel_dilation()