vela: Change Shape4D mutability usage

 - Removed the requirement to clone shapes when unique values are
   required, by making Shape4D immutable at the top level. This avoids
   issues with Shapes being unintentionally shared and then mutated as
   if they were value types.
 - Shape4D fields can no longer be assigned in place; changing a field
   now requires creating a new Shape4D (e.g. via the with_*() methods).

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Ic0dbfa349eb0215eabefb4f4e2cf99f12d83699c
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 8981e20..e26389a 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -15,66 +15,90 @@
 # limitations under the License.
 # Description:
 # Defines the class Shape4D.
+from collections import namedtuple
+
 from .numeric_util import full_shape
+from .numeric_util import round_up_divide
 
 
-class Shape4D:
+class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
     """
     4D Shape (in NHWC format)
     """
 
-    def __init__(self, shape, base=1):
-        assert shape is not None
-        assert len(shape) <= 4
-        self._shape4D = tuple(full_shape(4, shape, base))
+    def __new__(cls, n=1, h=1, w=1, c=1):
+        assert n is not None
+        if isinstance(n, list):
+            assert h == 1 and w == 1 and c == 1
+            tmp = full_shape(4, n, 1)
+            self = super(Shape4D, cls).__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3])
+        else:
+            self = super(Shape4D, cls).__new__(cls, n, h, w, c)
+        return self
+
+    @classmethod
+    def from_list(cls, shape, base=1):
+        tmp = full_shape(4, shape, base)
+        return cls(tmp[0], tmp[1], tmp[2], tmp[3])
+
+    @classmethod
+    def from_hwc(cls, h, w, c):
+        return cls(1, h, w, c)
+
+    def with_batch(self, new_batch):
+        return Shape4D(new_batch, self.height, self.width, self.depth)
+
+    def with_height(self, new_height):
+        return Shape4D(self.batch, new_height, self.width, self.depth)
+
+    def with_width(self, new_width):
+        return Shape4D(self.batch, self.height, new_width, self.depth)
+
+    def with_hw(self, new_height, new_width):
+        return Shape4D(self.batch, new_height, new_width, self.depth)
+
+    def with_depth(self, new_depth):
+        return Shape4D(self.batch, self.height, self.width, new_depth)
+
+    def add(self, n, h, w, c):
+        return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)
+
+    def __add__(self, rhs):
+        return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)
+
+    def __sub__(self, rhs):
+        return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)
+
+    def __floordiv__(self, rhs):
+        return Shape4D(
+            self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
+        )
+
+    def __mod__(self, rhs):
+        return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)
 
     def __str__(self):
-        return f"<Shape4D {self.as_list()}>"
+        return f"<Shape4D {list(self)}>"
 
-    def __eq__(self, other):
-        return self._shape4D == other._shape4D
+    def div_round_up(self, rhs):
+        return Shape4D(
+            round_up_divide(self.batch, rhs.batch),
+            round_up_divide(self.height, rhs.height),
+            round_up_divide(self.width, rhs.width),
+            round_up_divide(self.depth, rhs.depth),
+        )
 
-    def clone(self):
-        return Shape4D(self.as_list())
+    def elements(self):
+        return self.batch * self.width * self.height * self.depth
 
-    @property
-    def batch(self):
-        return self._shape4D[0]
+    def elements_wh(self):
+        return self.width * self.height
 
-    @property
-    def height(self):
-        return self._shape4D[1]
-
-    @property
-    def width(self):
-        return self._shape4D[2]
-
-    @property
-    def depth(self):
-        return self._shape4D[3]
-
-    @batch.setter
-    def batch(self, new_batch):
-        self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3])
-
-    @height.setter
-    def height(self, new_height):
-        self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3])
-
-    @width.setter
-    def width(self, new_width):
-        self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3])
-
-    @depth.setter
-    def depth(self, new_depth):
-        self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth)
-
-    def get_dim(self, dim):
-        assert -4 <= dim < 4
-        return self._shape4D[dim]
+    def is_empty(self):
+        return (self.batch + self.width + self.height + self.depth) == 0
 
     def as_list(self):
-        return list(self._shape4D)
+        return list(self)
 
     def get_hw_as_list(self):
         return list([self.height, self.width])