MLBEDSW-4334 Non-linear format decision in graph opt.

The check for whether the non-linear tensor format (NHCWB16) can be used
has been refactored:

- Flag avoid_NHCWB16 replaced with needs_linear_format.
- Restriction checks consolidated into one graph-optimiser function,
  check_format_restrictions (see the sketch below).
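
Illustrative sketch (not part of the patch) of the inverted flag semantics,
using a minimal stand-in for the real Tensor class in ethosu/vela/tensor.py
and assuming the new flag defaults to True, which is consistent with
check_format_restrictions only ever clearing it:

    class Tensor:
        def __init__(self, shape):
            self.shape = shape
            # Old scheme: NHCWB16 was allowed by default and scattered rewrites
            # set avoid_NHCWB16 = True to opt out.
            # New scheme: linear (NHWC) format is assumed until the single
            # graph-optimiser pass proves NHCWB16 is safe and clears the flag.
            self.needs_linear_format = True

    def check_format_restrictions(tens, arch):
        # ...all NHCWB16 restriction checks; any failure returns early...
        tens.needs_linear_format = False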

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Iec5c7996a1a6039cad052197f1ae56f7c0290440
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 56932db..dd540a7 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -104,8 +104,6 @@
 
         for idx, inp in enumerate(op.inputs):
             op.ifm_shapes[idx] = Shape4D(desired_shape)
-            if Shape4D(inp.shape) != op.ifm_shapes[idx]:
-                inp.avoid_NHCWB16 = True
         op.type = Op.PackReshaped
 
     inputs, axis = op.get_concat_inputs_axis()
@@ -125,12 +123,7 @@
         offset = concat_end
     assert ofm.shape[axis] == offset
 
-    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
-    # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
-    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
-    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
-    if axis == -1 or axis == (len(ofm.shape) - 1):
-        ofm.avoid_NHCWB16 = any(op2.write_offset.depth % 16 != 0 for op2 in ofm.ops if op2.write_offset is not None)
+    return op
 
 
 def rewrite_split_ops(tens, arch, nng):
@@ -171,10 +164,6 @@
 
                 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
 
-                # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
-                if (offset_start[-1] % 16) != 0:
-                    inp.avoid_NHCWB16 = True
-
         new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
@@ -224,6 +213,108 @@
             DebugDatabase.add_optimised(op, avgpool_op)
 
 
+def avoid_nhcwb16_for_concat(tens):
+    # If axis corresponds to the C-dimension, NHCWB16 can only be used for the output if all the concat_start's
+    # are a multiple of 16, since only then will the ofm address offset be 16-byte aligned for every operation.
+    # For other values of axis the address offsets are all based on c = 0, and those addresses are always
+    # 16-byte aligned due to the NHCWB16 format.
+    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
+
+
+def avoid_nhcwb16_for_split(tens):
+    # If the read offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided for the input
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            read_offset = cons_op.read_offsets[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            read_offset = cons_op.read_offsets[1]
+        else:
+            assert False
+        if read_offset is not None and (read_offset[-1] % 16) != 0:
+            return True
+    return False
+
+
+def avoid_nhcwb16_for_shapes(tens):
+    # check all producers/consumers to see if any op shape is preventing NHCWB16
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            cons_op_shape = cons_op.ifm_shapes[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            cons_op_shape = cons_op.ifm_shapes[1]
+        else:
+            assert False
+        if Shape4D(tens.shape) != cons_op_shape:
+            return True
+
+    for prod_op in tens.ops:
+        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
+            return True
+
+    return False
+
+
+# Check if the non-linear (NHCWB16) tensor format can be used
+def check_format_restrictions(tens, arch):
+    if len(tens.ops) < 1:
+        return
+    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
+        cons is None for cons in tens.consumer_list
+    ):
+        return
+
+    if not any(cons.run_on_npu for cons in tens.consumer_list):
+        return
+    if not any(prod.run_on_npu for prod in tens.ops):
+        return
+
+    # "Concat" ofm exception:
+    if avoid_nhcwb16_for_concat(tens):
+        return
+
+    # "Split" ifm exception:
+    if avoid_nhcwb16_for_split(tens):
+        return
+
+    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
+    if avoid_nhcwb16_for_shapes(tens):
+        return
+
+    for op in tens.consumer_list:
+        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
+            return
+        if op.type == Op.Reshape:
+            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+            # consumers do not also need to perform a reshape or if the OFM is going to
+            # be processed by CPU operations. No-op reshape consumers with empty lists
+            # (those that have no consumers, or null-consumers used as list terminators)
+            # must use normal NHWC output.
+
+            def incompatible_consumers(oper):
+                if oper and oper.type == Op.Reshape:
+                    for consumer in oper.outputs[0].consumer_list:
+                        yield from incompatible_consumers(consumer)
+                yield not oper or not oper.run_on_npu
+
+            if not any(incompatible_consumers(op)):
+
+                def get_rewrites(oper):
+                    if oper and oper.type == Op.Reshape:
+                        for consumer in oper.outputs[0].consumer_list:
+                            yield from get_rewrites(consumer)
+                        yield oper
+
+                # Detect no-op reshapes by comparing their full input and output tensor shapes.
+                inshape = op.ifm_shapes[0]
+                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
+                if not (compatible_shape and all(compatible_shape)):
+                    return
+            else:
+                return
+
+    tens.needs_linear_format = False
+
+
 def insert_copy_op_after_tens(tens):
     tens_cons_list_copy = tens.consumer_list.copy()
 
@@ -459,8 +550,6 @@
         assert batch_size * n_in_elems == elms
 
         op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
-        if Shape4D(op.ifm.shape) != op.ifm_shapes[0]:
-            op.ifm.avoid_NHCWB16 = True
     return op
 
 
@@ -473,8 +562,6 @@
             h, w = batching_split.get(n, (1, n))
             op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
 
-            op.ifm.avoid_NHCWB16 = True
-
             # Reshape Weights to be 4D. IO becomes HWIO
             weight_tensor = op.inputs[1]
             weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
@@ -483,7 +570,6 @@
             n = op.ofm_shapes[0].batch
             h, w = batching_split.get(n, (1, n))
             op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
-            op.ofm.avoid_NHCWB16 = True
     return op
 
 
@@ -550,9 +636,6 @@
                 axis_4D[idx] = axis
             op.ofm_shapes[idx] = Shape4D(output_shape)
 
-        if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
-            out_tens.avoid_NHCWB16 = True
-
     op.attrs["split_axis_4D"] = axis_4D
     return op
 
@@ -574,8 +657,6 @@
         for idx, out_tens in enumerate(op.outputs):
             op.ofm_shapes[idx] = Shape4D(desired_output_shape)
             axis_4D_list[idx] = axis_4D
-            if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
-                out_tens.avoid_NHCWB16 = True
 
         op.attrs["split_axis_4D"] = axis_4D_list
     return op
@@ -662,7 +743,6 @@
         ifm_shape = op.ifm_shapes[0]
         # IFM
         op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
-        op.ifm.avoid_NHCWB16 = True
 
         # Weights
         weight_shape = weight_tensor.shape
@@ -1129,16 +1209,12 @@
                 for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
                     if cons_ifm == ifm:
                         ifm_cons.set_input_tensor(ofm, ifm_idx)
-            if op.ifm_shapes[0] != op.ofm_shapes[0]:
-                ofm.avoid_NHCWB16 = True
         else:
             # Bypassed Reshape by replacing ofm with ifm
             for cons in ofm.consumer_list:
                 for ifm_idx, cons_ifm in enumerate(cons.inputs):
                     if cons_ifm == ofm:
                         cons.set_input_tensor(ifm, ifm_idx)
-            if op.ifm_shapes[0] != op.ofm_shapes[0]:
-                ifm.avoid_NHCWB16 = True
 
 
 def check_reshapes(op, arch):
@@ -1339,7 +1415,7 @@
         create_avg_pool_for_concat(
             op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
         )
-    ofm.avoid_NHCWB16 = True
+
     op.type = Op.ConcatTFLite
     return avgpool_op
 
@@ -1531,7 +1607,6 @@
         if h > 64:
             shape = [shape[0], 1, h * w, shape[3]]
             op.ifm_shapes[0] = Shape4D(shape)
-            inp.avoid_NHCWB16 = True
             if h > 256 and op.type == Op.AvgPool:
                 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
 
@@ -1688,6 +1763,11 @@
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
         sg.refresh_after_modification()
 
+    # Check Tensor Format restrictions
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [check_format_restrictions], [])
+        sg.refresh_after_modification()
+
     # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
     for sg in nng.subgraphs:
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])