MLBEDSW-6927: Add ofm_stride_multiplier attribute to operation

Allow sparse writing of the OFM by multiplying the H/W/C strides of the
OFM by the values of ofm_stride_multiplier.

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I65d742ad36ad3154e9914cdd22e2da928ad1f095
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 65473b8..9997031 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -592,7 +592,9 @@
         rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
         return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))
 
-    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, op_shape4D: Shape4D) -> Tuple:
+    def addresses_for_rolling_buffer(
+        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
+    ) -> Tuple:
         # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
 
         if self.storage_shape == []:
@@ -600,7 +602,7 @@
                 1,
                 1,
                 1,
-                [self.address_for_coordinate(start_coord, op_shape4D=op_shape4D), 0, 0, 0],
+                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
             )
 
         if self.is_standard_fm:
@@ -618,67 +620,28 @@
         box_width = crossing_x - start_coord[2]
 
         addresses: List = [0] * 4
-        addresses[0] = self.address_for_coordinate(start_coord, op_shape4D=op_shape4D)
+        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)
 
         if end_coord[2] > crossing_x:
             addresses[1] = self.address_for_coordinate(
-                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], op_shape4D=op_shape4D
+                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
             )
             raise UnsupportedFeatureError("Striping in vertical direction is not supported")
         if end_coord[1] > crossing_y:
             addresses[2] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], op_shape4D=op_shape4D
+                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
             )
         if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
             addresses[3] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, crossing_x, start_coord[3]], op_shape4D=op_shape4D
+                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
             )
 
         return box_height0, box_height0, box_width, addresses
 
-    def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, op_shape4D: Shape4D = None) -> int:
-        offset = self.address_offset_for_coordinate(coord, op_shape4D=op_shape4D, is_top_box=is_top_box)
-        assert offset is not None
-        return self.address + offset
+    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:
 
-    def get_strides_and_coord(
-        self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
-    ) -> Tuple[Optional[Shape], Optional[Shape]]:
-        if coord is None:
-            coord = [0] * min(len(self.storage_shape), 4)
-
-        if shape4D and self.is_standard_fm:
-            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
-        else:
-            augmented_shape = full_shape(4, self.storage_shape, 1)
-
-        augmented_coord = coord
-
-        while len(augmented_coord) < 4:
-            augmented_coord = [0] + augmented_coord
-
-        assert len(augmented_coord) == len(augmented_shape)
-
-        if self.format == TensorFormat.NHWC:
-            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]
-            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]
-
-        elif self.format == TensorFormat.NHCWB16:
-            channel_divisor = 16
-            augmented_shape = augmented_shape[0:4] + [1]
-            augmented_coord = (
-                [augmented_coord[0], augmented_coord[3] // channel_divisor]
-                + augmented_coord[1:3]
-                + [augmented_coord[3] % channel_divisor]
-            )
-
-            if augmented_shape[1] == 0:
-                augmented_shape[1] = 1
-
-        else:
-            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
-            return None, None
-
+        augmented_shape = self.get_augmented_shape(shape4D)
+        assert len(augmented_shape) == 5
         strides: List = [0] * len(augmented_shape)
         stride = self.element_size() * self.storage_compression_scale
 
@@ -695,13 +658,53 @@
             strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
             strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N
 
-        return strides, augmented_coord
-
-    def get_strides(self, shape4D: Optional[Shape4D] = None) -> Shape:
-        strides, _ = self.get_strides_and_coord(shape4D=shape4D)
-        assert strides is not None
         return strides
 
+    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:
+
+        if shape4D and self.is_standard_fm:
+            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
+        else:
+            augmented_shape = full_shape(4, self.storage_shape, 1)
+
+        if self.format == TensorFormat.NHWC:
+            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]
+
+        elif self.format == TensorFormat.NHCWB16:
+            augmented_shape = augmented_shape[0:4] + [1]
+
+            if augmented_shape[1] == 0:
+                augmented_shape[1] = 1
+
+        else:
+            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
+            return None
+
+        return augmented_shape
+
+    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
+        if coord is None:
+            coord = [0] * min(len(self.storage_shape), 4)
+
+        missing_len = 4 - len(coord)
+        augmented_coord = ([0] * missing_len) + coord
+
+        if self.format == TensorFormat.NHWC:
+            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]
+
+        elif self.format == TensorFormat.NHCWB16:
+            channel_divisor = 16
+            augmented_coord = (
+                [augmented_coord[0], augmented_coord[3] // channel_divisor]
+                + augmented_coord[1:3]
+                + [augmented_coord[3] % channel_divisor]
+            )
+        else:
+            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
+            return None
+
+        return augmented_coord
+
     def find_npu_op(self) -> Optional[Operation]:
         # Returns the NPU operator that uses this tensor
         for op in self.consumers():
@@ -743,8 +746,12 @@
         assert 0 <= index < len(self.compressed_values)
         return index == len(self.compressed_values) - 1
 
-    def address_offset_for_coordinate(
-        self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False
+    def address_for_coordinate(
+        self,
+        orig_coord: Shape,
+        strides: Optional[List[int]] = None,
+        op_shape4D: Optional[Shape4D] = None,
+        is_top_box: bool = False,
     ) -> Optional[int]:
         address_offset = 0
         assert self.purpose != TensorPurpose.Weights
@@ -771,18 +778,22 @@
         # handle wraparound for partial buffers. make sure to do this after subtracting top box:
         coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]
 
-        strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
-        if strides is None:
-            return None
+        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
+        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
+        if not strides:
+            strides = self.get_strides(op_shape4D)
 
         if is_top_box:
             address_offset += 1 * strides[-1]  # one element
 
+        augmented_coord = self.get_augmented_coord(coord)
+        assert augmented_coord is not None
+
         address_offset += np.dot(augmented_coord, strides)
 
         assert address_offset >= 0
         assert address_offset <= storage_size
-        return address_offset
+        return self.address + address_offset
 
     def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
         return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))