[MLBEDSW-2928] Add batching to softmax

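Add batching support to softmax by folding the batch dimension into the
height dimension, i.e. reshaping the IFM/OFM from [N, H, W, C] to
[1, N*H, W, C]. The IFM/OFM reshape now happens once, before dispatching
to the 8-bit or int16 graph builders (which no longer reshape on their
own). The ifm_max tensor is reshaped to [1, H, W, 1] so it matches the
reshaped IFM in the subtraction of pass 1.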

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I0b516f9bf2410fb86372b229beba4a7280c498cc
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 2834f8c..9e8b846 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -201,6 +201,14 @@
         ifm = self.op.inputs[0]
         ofm = self.op.outputs[0]
 
+        # Reshape ifm/ofm (if needed)
+        full_shape = ifm.get_full_shape()
+        if full_shape[0] > 1:
+            full_shape[1] *= full_shape[0]
+            full_shape[0] = 1
+        ifm = create_reshape_tensor(ifm, full_shape)
+        ofm = create_reshape_tensor(ofm, full_shape, False)
+
         if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
             return self.get_graph_8bit(ifm, ofm)
         elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
@@ -211,8 +219,6 @@
 
     def get_graph_8bit(self, ifm, ofm):
         exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
-        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
-        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
         no_scale_quant = ifm.quantization.clone()
         no_scale_quant.scale_f32 = None
         no_scale_quant.zero_point = 0
@@ -242,7 +248,7 @@
         # PASS 1 - Sub+LUT(exp)
         sub_op = Operation("SubAct", self.op.name + "_sub1")
         sub_op.add_input_tensor(ifm)
-        sub_op.add_input_tensor(ifm_max)
+        sub_op.add_input_tensor(create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]))
         sub_op.set_activation_lut(
             create_const_tensor(
                 sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
@@ -463,8 +469,6 @@
         return shr30_op
 
     def get_graph_int16(self, ifm, ofm):
-        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
-        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
         no_scale_quant = ifm.quantization.clone()
         no_scale_quant.scale_f32 = None