vela: Change Shape4D mutability usage

 - Removed the requirement to clone shapes whenever unique values are
   needed, by enforcing top-level immutability. This alleviates issues
   with Shape4D objects being unintentionally shared and then mutated
   as if they were value types.
 - Shape4D fields can no longer be assigned directly; a modified copy
   must be created instead (see the sketch below).
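
   For illustration only (not part of the patch), a typical call site
   moves from clone-and-mutate to deriving a new instance; `op` below
   is a placeholder for any operation holding ifm/ofm shape lists:

       # Before: clone, then mutate fields in place
       shape = op.ifm_shapes[0].clone()
       shape.height = shape.batch * shape.height
       shape.batch = 1

       # After: Shape4D is an immutable namedtuple; derive new values
       shape = op.ifm_shapes[0]
       op.ifm_shapes[0] = shape.with_height(shape.batch * shape.height).with_batch(1)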

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Ic0dbfa349eb0215eabefb4f4e2cf99f12d83699c
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e84e11e..1e3b131 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -101,14 +101,14 @@
         new_op.outputs = [ofm]
         new_op.attrs["concat_axis"] = axis_4D
         new_op.attrs["concat_start"] = offset
-        offset += op.ifm_shapes[idx].get_dim(axis_4D)
+        offset += op.ifm_shapes[idx][axis_4D]
 
         new_op.attrs["concat_end"] = offset
         new_op.run_on_npu = True
         ofm.ops.append(new_op)
         DebugDatabase.add_optimised(op, new_op)
-        new_op.ifm_shapes.append(op.ifm_shapes[idx].clone())
-        new_op.ofm_shapes.append(op.ofm_shapes[0].clone())
+        new_op.ifm_shapes.append(op.ifm_shapes[idx])
+        new_op.ofm_shapes.append(op.ofm_shapes[0])
     assert ofm.shape[axis] == offset
 
     # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -159,7 +159,7 @@
                     ofm_shape_idx = idx
                     break
 
-                offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
+                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
 
                 # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
                 if (offset_start[-1] % 16) != 0:
@@ -171,7 +171,7 @@
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
         new_op.ifm_shapes.append(Shape4D(inp.shape))
-        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx].clone())
+        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
         DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 9cbda45..c25c023 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -197,7 +197,7 @@
         self.pad_top = pad_top
         self.pad_bottom = pad_bottom
         for i in range(len(self.ofm_box.end_coord)):
-            assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0].get_dim(i)
+            assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0][i]
 
     def is_npu_pass_command(self):
         return True
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 3acd5e6..5bba3b6 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -444,8 +444,8 @@
         npu_block_type = primary_op.type.npu_block_type
 
         ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
-        ifm_tensor_shape = ps.primary_op.ifm_shapes[0].clone()
-        ofm_tensor_shape = ps.primary_op.ofm_shapes[0].clone()
+        ifm_tensor_shape = ps.primary_op.ifm_shapes[0]
+        ofm_tensor_shape = ps.primary_op.ofm_shapes[0]
         ofm_block.width = min(ofm_block.width, ofm_tensor_shape.width)
         ofm_block.height = min(ofm_block.height, ofm_tensor_shape.height)
         ofm_block.depth = min(ofm_block.depth, ofm_tensor_shape.depth)
@@ -480,9 +480,10 @@
 
             batch_size = ifm_tensor_shape.batch
 
-            # add in padding
-            ifm_tensor_shape.height += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
-            ifm_tensor_shape.width += explicit_padding[1] + explicit_padding[3]  # width  += left and right
+            # add in padding, height += top and bottom, width  += left and right
+            ifm_tensor_shape = ifm_tensor_shape.add(
+                0, explicit_padding[0] + explicit_padding[2], explicit_padding[1] + explicit_padding[3], 0
+            )
 
             if npu_block_type != NpuBlockType.Pooling:
                 if npu_block_type == NpuBlockType.ReduceSum:
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index c973b9c..abd235f 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -231,7 +231,7 @@
                 ofm_tensor = op.ofm
                 if ofm_tensor is None:
                     ofm_tensor = op.outputs[0]
-                ofm_shape = op.ofm_shapes[0].clone() if op.run_on_npu else None
+                ofm_shape = op.ofm_shapes[0] if op.run_on_npu else None
 
                 build_pass((op,), ofm_tensor, ofm_shape)
 
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 8981e20..e26389a 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -15,66 +15,90 @@
 # limitations under the License.
 # Description:
 # Defines the class Shape4D.
+from collections import namedtuple
+
 from .numeric_util import full_shape
+from .numeric_util import round_up_divide
 
 
-class Shape4D:
+class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
     """
     4D Shape (in NHWC format)
     """
 
-    def __init__(self, shape, base=1):
-        assert shape is not None
-        assert len(shape) <= 4
-        self._shape4D = tuple(full_shape(4, shape, base))
+    def __new__(cls, n=1, h=1, w=1, c=1):
+        assert n is not None
+        if isinstance(n, list):
+            assert h == 1 and w == 1 and c == 1
+            tmp = full_shape(4, n, 1)
+            self = super(Shape4D, cls).__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3])
+        else:
+            self = super(Shape4D, cls).__new__(cls, n, h, w, c)
+        return self
+
+    @classmethod
+    def from_list(cls, shape, base=1):
+        tmp = full_shape(4, shape, base)
+        return cls(tmp[0], tmp[1], tmp[2], tmp[3])
+
+    @classmethod
+    def from_hwc(cls, h, w, c):
+        return cls(1, h, w, c)
+
+    def with_batch(self, new_batch):
+        return Shape4D(new_batch, self.height, self.width, self.depth)
+
+    def with_height(self, new_height):
+        return Shape4D(self.batch, new_height, self.width, self.depth)
+
+    def with_width(self, new_width):
+        return Shape4D(self.batch, self.height, new_width, self.depth)
+
+    def with_hw(self, new_height, new_width):
+        return Shape4D(self.batch, new_height, new_width, self.depth)
+
+    def with_depth(self, new_depth):
+        return Shape4D(self.batch, self.height, self.width, new_depth)
+
+    def add(self, n, h, w, c):
+        return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)
+
+    def __add__(self, rhs):
+        return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)
+
+    def __sub__(self, rhs):
+        return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)
+
+    def __floordiv__(self, rhs):
+        return Shape4D(
+            self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
+        )
+
+    def __mod__(self, rhs):
+        return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)
 
     def __str__(self):
-        return f"<Shape4D {self.as_list()}>"
+        return f"<Shape4D {list(self)}>"
 
-    def __eq__(self, other):
-        return self._shape4D == other._shape4D
+    def div_round_up(self, rhs):
+        return Shape4D(
+            round_up_divide(self.batch, rhs.batch),
+            round_up_divide(self.height, rhs.height),
+            round_up_divide(self.width, rhs.width),
+            round_up_divide(self.depth, rhs.depth),
+        )
 
-    def clone(self):
-        return Shape4D(self.as_list())
+    def elements(self):
+        return self.batch * self.width * self.height * self.depth
 
-    @property
-    def batch(self):
-        return self._shape4D[0]
+    def elements_wh(self):
+        return self.width * self.height
 
-    @property
-    def height(self):
-        return self._shape4D[1]
-
-    @property
-    def width(self):
-        return self._shape4D[2]
-
-    @property
-    def depth(self):
-        return self._shape4D[3]
-
-    @batch.setter
-    def batch(self, new_batch):
-        self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3])
-
-    @height.setter
-    def height(self, new_height):
-        self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3])
-
-    @width.setter
-    def width(self, new_width):
-        self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3])
-
-    @depth.setter
-    def depth(self, new_depth):
-        self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth)
-
-    def get_dim(self, dim):
-        assert -4 <= dim < 4
-        return self._shape4D[dim]
+    def is_empty(self):
+        return (self.batch + self.width + self.height + self.depth) == 0
 
     def as_list(self):
-        return list(self._shape4D)
+        return list(self)
 
     def get_hw_as_list(self):
         return list([self.height, self.width])
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index c3b0611..4418f01 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -216,10 +216,9 @@
         # Reshape ifm/ofm (if needed)
         ifm_shape = self.op.ifm_shapes[0]
         if ifm_shape.batch > 1:
-            ifm_shape.height = ifm_shape.batch * ifm_shape.height
-            ifm_shape.batch = 1
+            self.op.ifm_shapes[0] = ifm_shape.with_height(ifm_shape.batch * ifm_shape.height).with_batch(1)
             self.op.ifm.avoid_NHCWB16 = True
-            self.op.ofm_shapes[0] = ifm_shape.clone()
+            self.op.ofm_shapes[0] = self.op.ifm_shapes[0]
             self.op.ofm.avoid_NHCWB16 = True
 
         if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype: