MLBEDSW-4034: New Scheduler Size or Performance Optimisation

 - Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 5d849d9..fd67403 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -16,8 +16,10 @@
 # Description:
 # Defines the class Shape4D.
 from collections import namedtuple
+from enum import Enum
 
 from .numeric_util import full_shape
+from .numeric_util import round_up
 from .numeric_util import round_up_divide
 
 
@@ -42,6 +44,27 @@
         return cls(tmp[0], tmp[1], tmp[2], tmp[3])
 
     @classmethod
+    def min(cls, lhs, rhs):
+        """Return a new shape whose axes are the element-wise minimum of lhs and rhs."""
+        batch, height = min(lhs.batch, rhs.batch), min(lhs.height, rhs.height)
+        return cls(batch, height, min(lhs.width, rhs.width), min(lhs.depth, rhs.depth))
+
+    @classmethod
+    def max(cls, lhs, rhs):
+        """Return a new shape whose axes are the element-wise maximum of lhs and rhs."""
+        batch, height = max(lhs.batch, rhs.batch), max(lhs.height, rhs.height)
+        return cls(batch, height, max(lhs.width, rhs.width), max(lhs.depth, rhs.depth))
+
+    @classmethod
+    def round_up(cls, lhs, rhs):
+        """Return a new shape with each lhs axis passed through round_up() against the same rhs axis."""
+        batch = round_up(lhs.batch, rhs.batch)
+        height = round_up(lhs.height, rhs.height)
+        width = round_up(lhs.width, rhs.width)
+        depth = round_up(lhs.depth, rhs.depth)
+        return cls(batch, height, width, depth)
+
+    @classmethod
     def from_hwc(cls, h, w, c):
         return cls(1, h, w, c)
 
@@ -60,6 +83,25 @@
     def with_depth(self, new_depth):
         return Shape4D(self.batch, self.height, self.width, new_depth)
 
+    def with_axis(self, axis, new_val):
+        dims = self.as_list()  # list form of this shape, so that one entry can be overwritten
+        dims[axis] = new_val
+        return Shape4D.from_list(dims)
+
+    @staticmethod
+    def _clip_len(pos, length, size):
+        """Return how much of a span of `length` starting at `pos` is left after clipping it to [0, size)."""
+        if pos < 0:
+            pos, length = 0, length + pos  # drop the part that lies before 0
+        return min(pos + length, size) - pos
+
+    def clip(self, offset, sub_shape):
+        """Return sub_shape, placed at offset, with every axis clipped to the extent of this shape."""
+        axes = ("batch", "height", "width", "depth")
+        # Each axis is clipped on its own, in the same order as the Shape4D constructor arguments
+        lengths = [Shape4D._clip_len(getattr(offset, a), getattr(sub_shape, a), getattr(self, a)) for a in axes]
+        return Shape4D(*lengths)
+
     def add(self, n, h, w, c):
         return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)
 
@@ -74,6 +116,9 @@
             self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
         )
 
+    def __truediv__(self, rhs):  # element-wise "/": true division, so int components come back as floats
+        return Shape4D(self.batch / rhs.batch, self.height / rhs.height, self.width / rhs.width, self.depth / rhs.depth)
+
     def __mod__(self, rhs):
         return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)
 
@@ -102,3 +147,75 @@
 
     def get_hw_as_list(self):
         return list([self.height, self.width])
+
+
+class VolumeIterator:
+    """
+    4D Volume iterator. Use to traverse 4D tensor volumes in smaller shapes.
+
+    Every step yields a tuple (offset, clipped): offset is the Shape4D position of the current
+    sub-volume and clipped is shape.clip(offset, sub_shape). The position advances by delta along
+    depth first, then width, height and finally batch (CWHN). An axis that reaches its extent in
+    shape restarts from 0, not from its start value. Iteration ends once the batch position has
+    reached shape.batch.
+    """
+
+    class Direction(Enum):
+        CWHN = 0
+
+    def __init__(
+        self,
+        shape: Shape4D,
+        sub_shape: Shape4D,
+        start: Shape4D = Shape4D(0, 0, 0, 0),
+        delta: Shape4D = None,
+        dir=Direction.CWHN,
+    ):
+        """
+        shape: the volume to traverse
+        sub_shape: size of the sub-volume reported at each step, before clipping to shape
+        start: position of the first sub-volume
+        delta: step taken along each axis, defaults to sub_shape
+        dir: traversal order; accepted but not read, CWHN is the only order implemented
+
+        Raises ValueError when any axis of the step is zero or negative.
+        """
+        step = sub_shape if delta is None else delta
+        # Every axis has to advance, otherwise the position never reaches the end of shape and the
+        # iteration does not terminate. This is an explicit per-axis check because an assert is
+        # removed under "python -O".
+        if min(step.batch, step.height, step.width, step.depth) <= 0:
+            raise ValueError("Iterator will not move")
+
+        self.b = start.batch
+        self.y = start.height
+        self.x = start.width
+        self.z = start.depth
+        self.shape = shape
+        self.sub_shape = sub_shape
+        self.delta = step
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        """Return (offset, clipped sub-shape) for the current position, then step to the next one."""
+        if self.b >= self.shape.batch:
+            raise StopIteration()
+
+        # Captured before the position is advanced: this is the offset being reported
+        offset = Shape4D(self.b, self.y, self.x, self.z)
+
+        # CWHN: step depth first; each axis that runs past its extent wraps to 0 and carries into the next
+        self.z += self.delta.depth
+        if self.z >= self.shape.depth:
+            self.z = 0
+            self.x += self.delta.width
+            if self.x >= self.shape.width:
+                self.x = 0
+                self.y += self.delta.height
+                if self.y >= self.shape.height:
+                    self.y = 0
+                    self.b += self.delta.batch
+
+        return offset, self.shape.clip(offset, self.sub_shape)