TOSA: Decompose elementwise op tensors

Added decomposition of elementwise op tensors whose
dimensions exceed the maximum size supported by the NPU.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I17a99cb72947d2f1064a631ad6975ce895c258d5
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 08b2a6a..fd1ee94 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -111,6 +111,9 @@
     def __sub__(self, rhs):
         return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)
 
+    def floordiv_const(self, const):
+        return Shape4D(self.batch // const, self.height // const, self.width // const, self.depth // const)
+
     def __floordiv__(self, rhs):
         return Shape4D(
             self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index f4aa453..1e059cc 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -503,21 +503,71 @@
     return convert_to_lut(op, table.values, "table")
 
 
-def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n):
-    part_op = op.clone()
-    offset = Shape4D(0, 0, 0, 0)
+def decompose_tensors_hwc(op):
+    max_t_size = 65535
+    ofm_shape = op.ofm_shapes[0]
+    ifm_shape = op.ifm_shapes[0]
+    ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
 
-    part_op.read_offsets[0] = offset.with_batch(ifm_offset_n)
-    part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1)
-    part_op.write_offset = offset.with_batch(ofm_offset_n)
-    part_op.write_shape = op.ofm_shapes[0].with_batch(1)
+    limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)
+
+    if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
+        ofm_split = ofm_shape.floordiv_const(max_t_size).add(1, 1, 1, 1)
+
+        for height in range(ofm_split.height):
+            for width in range(ofm_split.width):
+                for depth in range(ofm_split.depth):
+                    ofm_offset = Shape4D(0, height * max_t_size, width * max_t_size, depth * max_t_size)
+                    ofm_part_shape = ofm_shape.clip(ofm_offset, limit_shape)
+                    ofm_cut = (ofm_offset, ofm_part_shape)
+
+                    ifm_d = depth * max_t_size if ifm_shape.depth == ofm_shape.depth else 0
+                    ifm_w = width * max_t_size if ifm_shape.width == ofm_shape.width else 0
+                    ifm_h = height * max_t_size if ifm_shape.height == ofm_shape.height else 0
+                    ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+                    ifm_part_shape = ifm_shape.clip(ifm_offset, limit_shape)
+                    ifm_cut = (ifm_offset, ifm_part_shape)
+
+                    if ifm2_shape is not None:
+                        ifm2_d = depth * max_t_size if ifm2_shape.depth == ofm_shape.depth else 0
+                        ifm2_w = width * max_t_size if ifm2_shape.width == ofm_shape.width else 0
+                        ifm2_h = height * max_t_size if ifm2_shape.height == ofm_shape.height else 0
+                        ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+                        ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
+                        ifm2_cut = (ifm2_offset, ifm2_part_shape)
+                    else:
+                        ifm2_offset = None
+                        ifm2_cut = (None, None)
+
+                    create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
+        op.ofm.ops.remove(op)
+        op.ifm.consumer_list.remove(op)
+        if op.ifm2 is not None:
+            op.ifm2.consumer_list.remove(op)
+    return
+
+
+def create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut):
+    part_op = op.clone()
+    ifm_read_offset = op.read_offsets[0] if op.read_offsets[0] is not None else Shape4D(0, 0, 0, 0)
+    ofm_write_offset = op.write_offset if op.write_offset is not None else Shape4D(0, 0, 0, 0)
+    ifm_offset, ifm_shape = ifm_cut
+    ofm_offset, ofm_shape = ofm_cut
+
+    part_op.read_offsets[0] = ifm_read_offset + ifm_offset
+    part_op.read_shapes[0] = ifm_shape
+    part_op.write_offset = ofm_write_offset + ofm_offset
+    part_op.write_shape = ofm_shape
     part_op.ifm_shapes = op.ifm_shapes.copy()
     part_op.ofm_shapes = op.ofm_shapes.copy()
     part_op.ifm.consumer_list.append(part_op)
     op.ofm.ops.append(part_op)
-    if ifm2_offset_n:
-        part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n)
-        part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1)
+
+    ifm2_offset, ifm2_shape = ifm2_cut
+    if ifm2_offset:
+        ifm2_read_offset = op.read_offsets[1] if op.read_offsets[1] is not None else Shape4D(0, 0, 0, 0)
+        part_op.read_offsets[1] = ifm2_read_offset + ifm2_offset
+        part_op.read_shapes[1] = ifm2_shape
         part_op.ifm2.consumer_list.append(part_op)
 
 
@@ -528,114 +578,120 @@
     return Shape4D(stride_n, stride_y, stride_x, 1)
 
 
-def decomp_unary_elementwise(op):
+def get_elem_shapes_removed_singles(op):
     """
-    Decompose binary elementwise ops with Rank > 3 (H,W,D).
-    If Rank > 3, all the dimensions above H are viewed as the N dimension.
-    the elementwise operation will be decomposed to N (of ofm) elementwise operations.
-    By reading and writing with offsets from/to the ifm/ofm.
+    Returns the shapes of ofm/ifm(s) after removing all the dimensions that are 1 for all ifm(s)/ofm
     """
-    ifm = op.ifm
-    ofm = op.ofm
-    assert op.type.is_unary_elementwise_op()
-    assert None not in (ifm, ofm)
-    assert ifm.shape == ofm.shape
+    rank = len(op.ofm.shape)
+    binary = op.ifm2 is not None
+    new_ofm_shape = []
+    new_ifm_shape = []
+    new_ifm2_shape = []
+    for idx in range(rank):
+        if op.ofm.shape[idx] != 1:
+            new_ofm_shape.append(op.ofm.shape[idx])
+            new_ifm_shape.append(op.ifm.shape[idx])
+            if binary:
+                new_ifm2_shape.append(op.ifm2.shape[idx])
+    if new_ofm_shape == []:
+        new_ofm_shape = [1]
+        new_ifm_shape = [1]
+        new_ifm2_shape = [1] if binary else None
 
-    rank = len(ofm.shape)
-    if rank > 3:
-        n = rank - 3
-        ofm_decomp_shape = Shape4D(ofm.shape[0:n])
-        new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
-        op.ifm_shapes.append(Shape4D(new_ofm_shape))
-        op.ofm_shapes.append(Shape4D(new_ofm_shape))
-
-        if new_ofm_shape[0] == 1:
-            return
-
-        for height in range(ofm_decomp_shape.height):
-            for width in range(ofm_decomp_shape.width):
-                for depth in range(ofm_decomp_shape.depth):
-                    ifm_offset, ofm_offset = Shape4D(0, height, width, depth)
-                    create_elem_part_op(op, ifm_offset, None, ofm_offset)
-
-        ifm.consumer_list.remove(op)
-        ofm.ops.remove(op)
-    return
+    return new_ofm_shape, new_ifm_shape, new_ifm2_shape
 
 
-def decomp_binary_elementwise(op):
+def decomp_dims_elementwise(op):
     """
-    Decompose binary elementwise ops with Rank > 3 (H,W,D).
+    Decompose elementwise ops with Rank > 3 (H,W,D).
     If Rank > 3, all the dimensions above H are viewed as the N dimension.
     the elementwise operation will be decomposed to N (of ofm) elementwise operations.
     By reading and writing with offsets from/to the ifm(s)/ofm.
-    Note: Broadcast need to be handled, and TOSA allowes for broadcast by both ifm and ifm2
+    Note: Broadcast needs to be handled for binary elementwise ops, as TOSA allows broadcast by both ifm and ifm2
     """
 
     ifm = op.ifm
     ifm2 = op.ifm2
     ofm = op.ofm
-    assert op.type.is_binary_elementwise_op()
-    assert None not in (ifm, ifm2, ofm)
+    binary = op.ifm2 is not None
+    assert len(ofm.shape) <= 6
 
-    rank = len(ofm.shape)
+    # Remove dimensions that are all 1
+    new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
+    rank = len(new_ofm_shape)
+
     if rank > 3:
         n = rank - 3
-        ofm_decomp_shape = Shape4D(ofm.shape[0:n])
-        ifm_decomp_shape = Shape4D(ifm.shape[0:n])
-        ifm2_decomp_shape = Shape4D(ifm2.shape[0:n])
-
+        ofm_decomp_shape = Shape4D(new_ofm_shape[0:n])
         ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
-        ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
-        ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+        ofm_part_shape = Shape4D(new_ofm_shape[n:])
+        op.ofm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
 
-        new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
-        new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:]
-        new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:]
+        if binary:
+            ifm_decomp_shape = Shape4D(new_ifm_shape[0:n])
+            ifm2_decomp_shape = Shape4D(new_ifm2_shape[0:n])
+            ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
+            ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+            ifm_part_shape = Shape4D(new_ifm_shape[n:])
+            ifm2_part_shape = Shape4D(new_ifm2_shape[n:])
+            op.ifm_shapes.append(Shape4D([ifm_decomp_shape.elements()] + new_ifm_shape[n:]))
+            op.ifm_shapes.append(Shape4D([ifm2_decomp_shape.elements()] + new_ifm2_shape[n:]))
+        else:
+            op.ifm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
 
-        op.ofm_shapes.append(Shape4D(new_ofm_shape))
-        op.ifm_shapes.append(Shape4D(new_ifm_shape))
-        op.ifm_shapes.append(Shape4D(new_ifm2_shape))
-
-        if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1:
-            return
-
+        op_list = []
         for height in range(ofm_decomp_shape.height):
             for width in range(ofm_decomp_shape.width):
                 for depth in range(ofm_decomp_shape.depth):
                     ofm_offset = Shape4D(0, height, width, depth)
+                    ofm_offset = Shape4D(ofm_offset.dot_prod(ofm_decomp_stride), 0, 0, 0)
+                    ofm_cut = (ofm_offset, ofm_part_shape)
 
-                    ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
-                    ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
-                    ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
-                    ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+                    if binary:
+                        ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
+                        ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
+                        ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
+                        ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+                        ifm_offset = Shape4D(ifm_offset.dot_prod(ifm_decomp_stride), 0, 0, 0)
+                        ifm_cut = (ifm_offset, ifm_part_shape)
 
-                    ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
-                    ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
-                    ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
-                    ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+                        ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
+                        ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
+                        ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
+                        ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+                        ifm2_offset = Shape4D(ifm2_offset.dot_prod(ifm2_decomp_stride), 0, 0, 0)
+                        ifm2_cut = (ifm2_offset, ifm2_part_shape)
+                        op_list.append(create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut))
+                    else:
+                        op_list.append(create_elem_part_op(op, ofm_cut, (None, None), ofm_cut))  # unary: no ifm2 cut
 
-                    ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride)
-                    ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride)
-                    ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride)
-                    create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n)
-
-        ifm.consumer_list.remove(op)
-        ifm2.consumer_list.remove(op)
         ofm.ops.remove(op)
-    return
+        ifm.consumer_list.remove(op)
+        if binary:
+            ifm2.consumer_list.remove(op)
+    else:
+        op.ofm_shapes.append(Shape4D(new_ofm_shape))
+        op.ifm_shapes.append(Shape4D(new_ifm_shape))
+        op.ifm_shapes.append(Shape4D(new_ifm2_shape))
+
+    return [op]
 
 
 def decomp_elementwise(tens, arch, nng):
     """
     Decompose elementwise ops with Rank > 3 (H,W,D).
+    Decompose tensors whose dimensions exceed the NPU max size
     """
-    assert len(tens.ops) == 1
+    if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op():
+        op = tens.ops[0]
+        rank = len(op.ofm.shape)
+        assert rank <= 6
 
-    if tens.ops[0].type.is_binary_elementwise_op():
-        decomp_binary_elementwise(tens.ops[0])
-    elif tens.ops[0].type.is_unary_elementwise_op():
-        decomp_unary_elementwise(tens.ops[0])
+        decomp_list = []
+        decomp_list = decomp_dims_elementwise(op)
+
+        for part_op in decomp_list:
+            decompose_tensors_hwc(part_op)
     return tens
 
 
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index f5eddcc..1012a61 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -117,18 +117,19 @@
     # This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     @docstring_format_args(tens_dim_range)
-    def constraint_tens_dimension(cls, op):
-        "Tensor dimensions must be in the range [{}, {}]"
-        tens_min, tens_max = cls.tens_dim_range
+    def constraint_tens_dimension(self, op):
+        "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+        tens_min, tens_max = self.tens_dim_range
         valid = True
         extra = []
-        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
-        if not tensors:
-            tensors = [tens for tens in op.inputs if tens]
-        for tens in tensors:
-            if not all(tens_min <= dim <= tens_max for dim in tens.shape):
-                valid = False
-                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+        if op.type not in self.binary_elem_wise_add_mul_sub:
+            tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+            if not tensors:
+                tensors = [tens for tens in op.inputs if tens]
+            for tens in tensors:
+                if not all(tens_min <= dim <= tens_max for dim in tens.shape):
+                    valid = False
+                    extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
         return valid, ", ".join(extra)
 
     # TODO This is for a HW limitation, that is to be resolved in SW later on