MLBEDSW-4334 Non-linear format decision in graph opt.

The check for whether a non-linear tensor format can be used has been
refactored.

- Flag avoid_NHCWB16 is replaced with needs_linear_format
- All restriction checks are now located in one function in the graph optimiser.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Iec5c7996a1a6039cad052197f1ae56f7c0290440
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 56932db..dd540a7 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -104,8 +104,6 @@
 
         for idx, inp in enumerate(op.inputs):
             op.ifm_shapes[idx] = Shape4D(desired_shape)
-            if Shape4D(inp.shape) != op.ifm_shapes[idx]:
-                inp.avoid_NHCWB16 = True
         op.type = Op.PackReshaped
 
     inputs, axis = op.get_concat_inputs_axis()
@@ -125,12 +123,7 @@
         offset = concat_end
     assert ofm.shape[axis] == offset
 
-    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
-    # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
-    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
-    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
-    if axis == -1 or axis == (len(ofm.shape) - 1):
-        ofm.avoid_NHCWB16 = any(op2.write_offset.depth % 16 != 0 for op2 in ofm.ops if op2.write_offset is not None)
+    return op
 
 
 def rewrite_split_ops(tens, arch, nng):
@@ -171,10 +164,6 @@
 
                 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
 
-                # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
-                if (offset_start[-1] % 16) != 0:
-                    inp.avoid_NHCWB16 = True
-
         new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
@@ -224,6 +213,108 @@
             DebugDatabase.add_optimised(op, avgpool_op)
 
 
+def avoid_nhcwb16_for_concat(tens):
+    # If axis corresponds to the C-dimension, NHCWB16 can only be used in the output if all the concat_starts are
+    # a multiple of 16, because only then will the address offset for the ofm, for all operations, be 16 byte
+    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
+    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
+    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
+
+
+def avoid_nhcwb16_for_split(tens):
+    # If a read offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided in the input
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            read_offset = cons_op.read_offsets[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            read_offset = cons_op.read_offsets[1]
+        else:
+            assert False
+        if read_offset is not None and (read_offset[-1] % 16) != 0:
+            return True
+    return False
+
+
+def avoid_nhcwb16_for_shapes(tens):
+    # check all producers/consumers to see if any op shape is preventing NHCWB16
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            cons_op_shape = cons_op.ifm_shapes[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            cons_op_shape = cons_op.ifm_shapes[1]
+        else:
+            assert False
+        if Shape4D(tens.shape) != cons_op_shape:
+            return True
+
+    for prod_op in tens.ops:
+        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
+            return True
+
+    return False
+
+
+# Check if a non-linear format can be used for the tensor; if so, needs_linear_format is set to False
+def check_format_restrictions(tens, arch):
+    if len(tens.ops) < 1:
+        return
+    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
+        cons is None for cons in tens.consumer_list
+    ):
+        return
+
+    if not any(cons.run_on_npu for cons in tens.consumer_list):
+        return
+    if not any(prod.run_on_npu for prod in tens.ops):
+        return
+
+    # "Concat" ofm exception:
+    if avoid_nhcwb16_for_concat(tens):
+        return
+
+    # "Split" ifm exception:
+    if avoid_nhcwb16_for_split(tens):
+        return
+
+    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
+    if avoid_nhcwb16_for_shapes(tens):
+        return
+
+    for op in tens.consumer_list:
+        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
+            return
+        if op.type == Op.Reshape:
+            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+            # consumers do not also need to perform a reshape or if the OFM is going to
+            # be processed by CPU operations. No-op reshape consumers with empty lists
+            # (those that have no consumers, or null-consumers used as list terminators)
+            # must use normal NHWC output.
+
+            def incompatible_consumers(oper):
+                if oper and oper.type == Op.Reshape:
+                    for consumer in oper.outputs[0].consumer_list:
+                        yield from incompatible_consumers(consumer)
+                yield not oper or not oper.run_on_npu
+
+            if not any(incompatible_consumers(op)):
+
+                def get_rewrites(oper):
+                    if oper and oper.type == Op.Reshape:
+                        for consumer in oper.outputs[0].consumer_list:
+                            yield from get_rewrites(consumer)
+                        yield oper
+
+                # Detect no-op reshapes by comparing their full input and output tensor shapes.
+                inshape = op.ifm_shapes[0]
+                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
+                if not (compatible_shape and all(compatible_shape)):
+                    return
+            else:
+                return
+
+    tens.needs_linear_format = False
+
+
 def insert_copy_op_after_tens(tens):
     tens_cons_list_copy = tens.consumer_list.copy()
 
@@ -459,8 +550,6 @@
         assert batch_size * n_in_elems == elms
 
         op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
-        if Shape4D(op.ifm.shape) != op.ifm_shapes[0]:
-            op.ifm.avoid_NHCWB16 = True
     return op
 
 
@@ -473,8 +562,6 @@
             h, w = batching_split.get(n, (1, n))
             op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
 
-            op.ifm.avoid_NHCWB16 = True
-
             # Reshape Weights to be 4D. IO becomes HWIO
             weight_tensor = op.inputs[1]
             weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
@@ -483,7 +570,6 @@
             n = op.ofm_shapes[0].batch
             h, w = batching_split.get(n, (1, n))
             op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
-            op.ofm.avoid_NHCWB16 = True
     return op
 
 
@@ -550,9 +636,6 @@
                 axis_4D[idx] = axis
             op.ofm_shapes[idx] = Shape4D(output_shape)
 
-        if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
-            out_tens.avoid_NHCWB16 = True
-
     op.attrs["split_axis_4D"] = axis_4D
     return op
 
@@ -574,8 +657,6 @@
         for idx, out_tens in enumerate(op.outputs):
             op.ofm_shapes[idx] = Shape4D(desired_output_shape)
             axis_4D_list[idx] = axis_4D
-            if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
-                out_tens.avoid_NHCWB16 = True
 
         op.attrs["split_axis_4D"] = axis_4D_list
     return op
@@ -662,7 +743,6 @@
         ifm_shape = op.ifm_shapes[0]
         # IFM
         op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
-        op.ifm.avoid_NHCWB16 = True
 
         # Weights
         weight_shape = weight_tensor.shape
@@ -1129,16 +1209,12 @@
                 for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
                     if cons_ifm == ifm:
                         ifm_cons.set_input_tensor(ofm, ifm_idx)
-            if op.ifm_shapes[0] != op.ofm_shapes[0]:
-                ofm.avoid_NHCWB16 = True
         else:
             # Bypassed Reshape by replacing ofm with ifm
             for cons in ofm.consumer_list:
                 for ifm_idx, cons_ifm in enumerate(cons.inputs):
                     if cons_ifm == ofm:
                         cons.set_input_tensor(ifm, ifm_idx)
-            if op.ifm_shapes[0] != op.ofm_shapes[0]:
-                ifm.avoid_NHCWB16 = True
 
 
 def check_reshapes(op, arch):
@@ -1339,7 +1415,7 @@
         create_avg_pool_for_concat(
             op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
         )
-    ofm.avoid_NHCWB16 = True
+
     op.type = Op.ConcatTFLite
     return avgpool_op
 
@@ -1531,7 +1607,6 @@
         if h > 64:
             shape = [shape[0], 1, h * w, shape[3]]
             op.ifm_shapes[0] = Shape4D(shape)
-            inp.avoid_NHCWB16 = True
             if h > 256 and op.type == Op.AvgPool:
                 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
 
@@ -1688,6 +1763,11 @@
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
         sg.refresh_after_modification()
 
+    # Check Tensor Format restrictions
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [check_format_restrictions], [])
+        sg.refresh_after_modification()
+
     # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
     for sg in nng.subgraphs:
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index c1572f4..f3d5e85 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -41,6 +41,7 @@
         tens.purpose = purpose
     elif tens.purpose not in (purpose, TensorPurpose.LUT):
         assert 0, "Cannot resolve tensor purpose {} and {} for tensor {}".format(tens.purpose, purpose, tens)
+
     fmt = get_format(purpose, arch)
     tens.set_format(fmt, arch)
     tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index e315f1f..c83f8f5 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -389,7 +389,8 @@
     elem_size = tensor.dtype.size_in_bytes()
     is_ifm = direction == BandwidthDirection.Read
     tens = tensor.clone()
-    if not tens.avoid_NHCWB16:
+
+    if not tensor.needs_linear_format:
         tens.set_format(TensorFormat.NHCWB16, arch)
     strides = tens.get_strides(shape4D=shape4D)
 
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 417f27e..c51a6b5 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -70,8 +70,6 @@
     op.set_output_tensor(ofm)
     op.ifm_shapes.append(ifm_shape)
     op.ofm_shapes.append(Shape4D(ofm.shape))
-    op.ifm.avoid_NHCWB16 = True
-    op.ofm.avoid_NHCWB16 = True
     return op
 
 
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 6ee06e2..65d3313 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -671,8 +671,8 @@
         for pred_candidate in ps.dag_predecessors:
             if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                 # we found a predecessor that produces this IFM tensor
-                if not ifm_tensor.avoid_NHCWB16:
-                    # and NHCWB16 format is not to be avoided
+                if not ifm_tensor.needs_linear_format:
+                    # and NHCWB16 can be used
                     if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                         # and it only has one successor, namely us
                         if pred_candidate.placement == PassPlacement.Npu:
@@ -965,7 +965,7 @@
                     if output.purpose != TensorPurpose.FeatureMap:
                         continue
 
-                    use_NHCWB16 = not output.avoid_NHCWB16
+                    use_NHCWB16 = not output.needs_linear_format
                     use_fast_storage = True
                     rewrites = []
                     for op in output.consumer_list:
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 520ec23..13ca319 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -217,9 +217,7 @@
         ifm_shape = self.op.ifm_shapes[0]
         if ifm_shape.batch > 1:
             self.op.ifm_shapes[0] = ifm_shape.with_height(ifm_shape.batch * ifm_shape.height).with_batch(1)
-            self.op.ifm.avoid_NHCWB16 = True
             self.op.ofm_shapes[0] = self.op.ifm_shapes[0]
-            self.op.ofm.avoid_NHCWB16 = True
 
         if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
             return self.get_graph_8bit(ifm, ofm)
@@ -262,7 +260,6 @@
         sub_op_quantization = one_scale_quant.clone()
         sub_op_quantization.zero_point = 127
         ifm_max_shape = Shape4D([1, ifm_shape.height, ifm_shape.width, 1])
-        ifm_max.avoid_NHCWB16 = True
         sub_op = create_sub(
             f"{self.op.name}_sub{pass_number}",
             ifm,
@@ -449,7 +446,6 @@
 
         # PASS 1 - Sub
         ifm_max_shape = Shape4D([1, ifm_shape.height, ifm_shape.width, 1])
-        ifm_max.avoid_NHCWB16 = True
         sub1_ofm = add_op_get_ofm(
             create_sub(
                 f"{self.op.name}_sub{pass_number}",
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index e915363..15bd05e 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -372,7 +372,7 @@
         "block_traversal",
         "equivalence_id",
         "resampling_mode",
-        "avoid_NHCWB16",
+        "needs_linear_format",
     )
     AllocationQuantum = 16
 
@@ -418,7 +418,7 @@
         self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default
         self.resampling_mode: resampling_mode = resampling_mode.NONE
 
-        self.avoid_NHCWB16: bool = False
+        self.needs_linear_format = True
 
     @property
     def address(self) -> int:
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index d9e171d..83a3dda 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -296,7 +296,6 @@
         "dilation_h_factor": 1,
     }
     pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
-    pool_op.add_input_tensor(out)
     pad_op.run_on_npu = True
     pool_op.run_on_npu = True
     nng = testutil.create_graph([pad_op, pool_op])