TOSA: Elementwise Rank > 4 and Batch > 1

Added support for elementwise operations:
 -Support for up to Rank == 6
 -Support for Batch > 1 for Rank == 4
 -For binary elementwise ops this includes handling
  of broadcasting in dimensions above the H-dimension
  (see the example below)
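
Illustrative example (shapes chosen for this description only):
an ADD with ifm [2, 3, 8, 8, 16], ifm2 [2, 1, 8, 8, 16] and
ofm [2, 3, 8, 8, 16] combines the dimensions above H into N = 2*3 = 6
and is decomposed into 6 rank-4 part operations. Each part op reads
and writes a [1, 8, 8, 16] slice at its batch offset; ifm2 is read
with offset 0 in its broadcast dimension.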

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I73850bbfb288077a99bd2ceecbf989172016da24
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 1558b94..b426792 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -545,6 +545,7 @@
         res.rounding_mode = self.rounding_mode
         res.explicit_scaling = self.explicit_scaling
         res.low_precision_scaling = self.low_precision_scaling
+        res.rescale = self.rescale
 
         return res
 
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index fd67403..08b2a6a 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -136,6 +136,9 @@
     def elements(self):
         return self.batch * self.width * self.height * self.depth
 
+    def dot_prod(self, rhs):
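+        # Sum of the products of the four dimensions; with NHWC strides as rhs
+        # this flattens a 4D offset into a linear element index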
+        return self.batch * rhs.batch + self.width * rhs.width + self.height * rhs.height + self.depth * rhs.depth
+
     def elements_wh(self):
         return self.width * self.height
 
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 37fd06e..2e70d72 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -632,7 +632,7 @@
         self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
     ) -> Tuple[Optional[Shape], Optional[Shape]]:
         if coord is None:
-            coord = [0] * len(self.storage_shape)
+            coord = [0] * min(len(self.storage_shape), 4)
 
         if shape4D and self.is_standard_fm:
             augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 1ef0444..f4aa453 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -503,6 +503,142 @@
     return convert_to_lut(op, table.values, "table")
 
 
+def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n):
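+    # Clone the elementwise op into a part-op that reads and writes a single
+    # batch slice of the decomposed shapes, at the given flat N offsets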
+    part_op = op.clone()
+    offset = Shape4D(0, 0, 0, 0)
+
+    part_op.read_offsets[0] = offset.with_batch(ifm_offset_n)
+    part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1)
+    part_op.write_offset = offset.with_batch(ofm_offset_n)
+    part_op.write_shape = op.ofm_shapes[0].with_batch(1)
+    part_op.ifm_shapes = op.ifm_shapes.copy()
+    part_op.ofm_shapes = op.ofm_shapes.copy()
+    part_op.ifm.consumer_list.append(part_op)
+    op.ofm.ops.append(part_op)
+    if ifm2_offset_n is not None:
+        part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n)
+        part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1)
+        part_op.ifm2.consumer_list.append(part_op)
+
+
+def get_nhwc_stride(shape):
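+    # Element strides for an NHWC-ordered shape; the innermost (depth) stride is 1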
+    stride_x = shape.depth
+    stride_y = shape.width * stride_x
+    stride_n = shape.height * stride_y
+    return Shape4D(stride_n, stride_y, stride_x, 1)
+
+
+def decomp_unary_elementwise(op):
+    """
+    Decompose unary elementwise ops with Rank > 3 (H,W,D).
+    If Rank > 3, all the dimensions above H are viewed as the N dimension.
+    The elementwise operation is decomposed into N (of the ofm) elementwise
+    operations that read and write with offsets into the ifm/ofm.
+    """
+    ifm = op.ifm
+    ofm = op.ofm
+    assert op.type.is_unary_elementwise_op()
+    assert None not in (ifm, ofm)
+    assert ifm.shape == ofm.shape
+
+    rank = len(ofm.shape)
+    if rank > 3:
+        n = rank - 3
+        ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+        new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+        op.ifm_shapes.append(Shape4D(new_ofm_shape))
+        op.ofm_shapes.append(Shape4D(new_ofm_shape))
+
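+        # Nothing to decompose if the combined N dimension is 1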
+        if new_ofm_shape[0] == 1:
+            return
+
+        ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
+
+        for height in range(ofm_decomp_shape.height):
+            for width in range(ofm_decomp_shape.width):
+                for depth in range(ofm_decomp_shape.depth):
+                    # ifm and ofm have the same shape, so they share the same offset
+                    offset_n = Shape4D(0, height, width, depth).dot_prod(ofm_decomp_stride)
+                    create_elem_part_op(op, offset_n, None, offset_n)
+
+        ifm.consumer_list.remove(op)
+        ofm.ops.remove(op)
+    return
+
+
+def decomp_binary_elementwise(op):
+    """
+    Decompose binary elementwise ops with Rank > 3 (H,W,D).
+    If Rank > 3, all the dimensions above H are viewed as the N dimension.
+    The elementwise operation is decomposed into N (of the ofm) elementwise
+    operations that read and write with offsets into the ifm(s)/ofm.
+    Note: broadcasting needs to be handled, and TOSA allows broadcasting by both ifm and ifm2.
+    """
+
+    ifm = op.ifm
+    ifm2 = op.ifm2
+    ofm = op.ofm
+    assert op.type.is_binary_elementwise_op()
+    assert None not in (ifm, ifm2, ofm)
+
+    rank = len(ofm.shape)
+    if rank > 3:
+        n = rank - 3
+        ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+        ifm_decomp_shape = Shape4D(ifm.shape[0:n])
+        ifm2_decomp_shape = Shape4D(ifm2.shape[0:n])
+
+        ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
+        ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
+        ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+
+        new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+        new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:]
+        new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:]
+
+        op.ofm_shapes.append(Shape4D(new_ofm_shape))
+        op.ifm_shapes.append(Shape4D(new_ifm_shape))
+        op.ifm_shapes.append(Shape4D(new_ifm2_shape))
+
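+        # Nothing to decompose if the combined N dimension is 1 for all tensors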
+        if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1:
+            return
+
+        for height in range(ofm_decomp_shape.height):
+            for width in range(ofm_decomp_shape.width):
+                for depth in range(ofm_decomp_shape.depth):
+                    ofm_offset = Shape4D(0, height, width, depth)
+
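+                    # A broadcast input is read from offset 0 in every decomposed
+                    # dimension that does not match the ofm's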
+                    ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
+                    ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
+                    ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
+                    ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+
+                    ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
+                    ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
+                    ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
+                    ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+
+                    ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride)
+                    ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride)
+                    ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride)
+                    create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n)
+
+        ifm.consumer_list.remove(op)
+        ifm2.consumer_list.remove(op)
+        ofm.ops.remove(op)
+    return
+
+
+def decomp_elementwise(tens, arch, nng):
+    """
+    Decompose elementwise ops with Rank > 3 (H,W,D).
+    """
+    assert len(tens.ops) == 1
+
+    if tens.ops[0].type.is_binary_elementwise_op():
+        decomp_binary_elementwise(tens.ops[0])
+    elif tens.ops[0].type.is_unary_elementwise_op():
+        decomp_unary_elementwise(tens.ops[0])
+    return tens
+
+
 def fixup_quantization(op, arch, nng):
     if op.ifm and op.ifm.quantization.zero_point is None:
         op.ifm.quantization.zero_point = 0
@@ -521,6 +657,13 @@
 
 
 def tosa_optimise_graph(nng, arch):
+
+    # Decomposing to 4 dimensions
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+        )
+
     # Pre-processing step
     pre_process_list = [
         supported_operator_check,
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 98df27e..f5eddcc 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -40,15 +40,15 @@
     mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
     memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,))
     binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
+    elem_wise_ops = binary_elem_wise_add_mul_sub
     type_conversion_ops = set((Op.Rescale,))
     relu_ops = set((Op.Clamp, Op.ReluN,))
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))
 
     npu_post_ops = activation_ops
-    supported_operators = (
-        mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops
-    )
+
+    supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
 
     # Supported data types
     # TODO will differ compared to TensorFlow Lite, currently set to the same
@@ -132,35 +132,37 @@
         return valid, ", ".join(extra)
 
     # TODO This is for a HW limitation, that is to be resolved in SW later on
-    @staticmethod
-    def constraint_rank(op):
-        "Tensor rank must be <= 4"
+    @classmethod
+    def constraint_rank(self, op):
+        "Tensor rank must be <= 4, if not elementwise"
         valid = True
         extra = []
-        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
-        if not tensors:
-            tensors = [tens for tens in op.inputs if tens]
-        for tens in tensors:
-            rank = len(tens.shape)
-            if not rank <= 4:
-                valid = False
-                extra.append(f"Tensor '{tens.name}' has rank: {rank}")
+        if op.type not in self.binary_elem_wise_add_mul_sub:
+            tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+            if not tensors:
+                tensors = [tens for tens in op.inputs if tens]
+            for tens in tensors:
+                rank = len(tens.shape)
+                if not rank <= 4:
+                    valid = False
+                    extra.append(f"Tensor '{tens.name}' has rank: {rank}")
         return valid, ", ".join(extra)
 
     # TODO This is for a HW limitation, that is to be resolved in SW later on
-    @staticmethod
-    def constraint_batch(op):
-        "If Tensor rank is 4 batch of ifms/ofm must be 1"
+    @classmethod
+    def constraint_batch(self, op):
+        "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
         valid = True
         extra = []
-        tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
-        if not tensors:
-            tensors = [tens for tens in op.inputs if tens]
-        for tens in tensors:
-            rank = len(tens.shape)
-            if rank == 4 and tens.shape[0] != 1:
-                valid = False
-                extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
+        if op.type not in self.binary_elem_wise_add_mul_sub:
+            tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
+            if not tensors:
+                tensors = [tens for tens in op.inputs if tens]
+            for tens in tensors:
+                rank = len(tens.shape)
+                if rank == 4 and tens.shape[0] != 1:
+                    valid = False
+                    extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
         return valid, ", ".join(extra)
 
     @staticmethod