MLBEDSW-3654 Add/use op ifm/ofm shapes

Add ifm/ofm shapes to the op, and change the code to rely on these
shapes.
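
For illustration (assuming full_shape from numeric_util left-pads a
shape with the given fill value), the stored shapes are always 4D:

    full_shape(4, [10, 20], 1)     -> [1, 1, 10, 20]  # 2D shape padded to 4D
    full_shape(4, [2, 3, 4, 5], 1) -> [2, 3, 4, 5]    # already 4D, unchanged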

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 30c32ac..be26a26 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -27,6 +27,7 @@
 from .errors import VelaError
 from .numeric_util import full_shape
 
+
 if TYPE_CHECKING:
     from .tensor import Tensor
 
@@ -129,7 +130,7 @@
     Concat = OperatorInfo(indices=CONCAT_INDICES)
     ConcatEmbeddings = OperatorInfo()
     ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
-    ConcatTFLite = OperatorInfo()
+    ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES)
     Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
     Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
     Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
@@ -197,7 +198,7 @@
     NonMaxSuppressionV5 = OperatorInfo()
     NotEqual = OperatorInfo()
     OneHot = OperatorInfo()
-    Pack = OperatorInfo()
+    Pack = OperatorInfo(indices=IFM_INDICES)
     PackReshaped = OperatorInfo(indices=IFM_INDICES)
     Pad = OperatorInfo()
     PadV2 = OperatorInfo()
@@ -260,7 +261,7 @@
     UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
     UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
     Unique = OperatorInfo()
-    Unpack = OperatorInfo()
+    Unpack = OperatorInfo(indices=IFM_INDICES)
     UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
     Where = OperatorInfo()
     While = OperatorInfo()
@@ -305,14 +306,17 @@
         return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
 
     def is_split_op(self):
-        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)
+        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)
 
     def is_concat_op(self):
-        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)
+        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)
 
     def needs_bias(self):
         return bool(self.info.indices.biases)
 
+    def needs_shapes(self):
+        return bool(self.info.indices.ifms)
+
     @classmethod
     def op_set(cls, predicate):
         # Returns the set of all operator codes that fulfill the given predicate
@@ -400,6 +404,8 @@
         "forced_output_quantization",
         "activation_lut",
         "_kernel",
+        "ifm_shapes",
+        "ofm_shapes",
     )
 
     def __init__(self, op_type: Op, name: str):
@@ -421,6 +427,8 @@
         self.op_index = None  # input network operator index
         self.activation_lut = None
         self._kernel = None
+        self.ifm_shapes = []
+        self.ofm_shapes = []
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
@@ -697,3 +705,35 @@
         lines += _print_tensors(self.outputs)
 
         raise VelaError("\n".join(lines))
+
+    def set_ifm_ofm_shapes(self):
+        ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()
+
+        # Set all ifm/ofm shapes on the op, normalised to 4D
+        if self.type == Op.FullyConnected:
+            n_in_elems = weight_tensor.shape[-2]  # input depth of the weights
+            elms = ifm_tensor.elements()
+            batch_size = elms // n_in_elems  # remaining elements form the batch
+            assert batch_size * n_in_elems == elms
+
+            self.ifm_shapes.append([batch_size, 1, 1, n_in_elems])
+            self.ofm_shapes.append(ofm_tensor.get_full_shape())
+        elif self.type == Op.Softmax:
+            self.ifm_shapes.append(ifm_tensor.get_full_shape())
+            self.ofm_shapes.append(ofm_tensor.get_full_shape())
+        elif self.type.is_split_op() or self.type.is_concat_op():
+            for inp in self.inputs:
+                if inp is not None:
+                    self.ifm_shapes.append(full_shape(4, inp.shape, 1))
+                else:
+                    self.ifm_shapes.append(None)
+            for out in self.outputs:
+                if out is not None:
+                    self.ofm_shapes.append(full_shape(4, out.shape, 1))
+                else:
+                    self.ofm_shapes.append(None)
+        else:
+            self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1))
+            if ifm2_tensor is not None:
+                self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1))
+            self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1))
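
For reference, a minimal standalone sketch of the FullyConnected case
above. full_shape is reimplemented here only for illustration, under
the assumption that it left-pads with the fill value; the concrete
numbers are hypothetical:

    def full_shape(dim, shape, fill):
        # Assumed behaviour: left-pad `shape` with `fill` to `dim` dims.
        return [fill] * (dim - len(shape)) + list(shape)

    # FullyConnected: the batch size is the ifm element count divided by
    # the input depth of the weights (weight_tensor.shape[-2]).
    ifm_shape = [8, 40]                 # hypothetical 2D ifm
    n_in_elems = 40                     # weights input depth
    elms = ifm_shape[0] * ifm_shape[1]  # ifm_tensor.elements()
    batch_size = elms // n_in_elems
    assert batch_size * n_in_elems == elms
    assert [batch_size, 1, 1, n_in_elems] == [8, 1, 1, 40]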