MLBEDSW-3194: Updated elementwise IFM banks count

Signed-off-by: Andreas Nevalainen <andreas.nevalainen@arm.com>
Change-Id: Ie404a0c13e7c7de0eff649f77e0147a0f3d73acd
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 073b50f..4f3fe7d 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -585,7 +585,7 @@
                         # Set IFM2_IB_START to the latter half of the IB space
                         ifm_ib_start = shared_buffer.bank_locations[SharedBufferArea.IFM]
                         emit.cmd0_with_param(
-                            cmd0.NPU_SET_IFM2_IB_START, (shram_required - ifm_ib_start) / 2 + ifm_ib_start
+                            cmd0.NPU_SET_IFM2_IB_START, (shram_required - ifm_ib_start) // shared_buffer.ifm_count + ifm_ib_start
                         )
 
                     emit.cmd0_with_param(cmd0.NPU_SET_IFM2_BROADCAST, ifm2_broadcast)
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index aa5f4c8..f52d3a9 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -47,6 +47,7 @@
         self.kernel = Kernel(1, 1)
         self.is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise
         self.uses_lut = False
+        self.ifm_count = 1
 
         if ps.primary_op:
             strides = ps.primary_op.attrs.get("strides", strides)
@@ -82,11 +83,19 @@
         if ifm_tensor:
             self.ifm_resampling_mode = ifm_tensor.resampling_mode
             self.ifm_bits = ifm_tensor.dtype.size_in_bits()
-            if ifm_tensor.shape == [] and self.is_elementwise:
-                # Elementwise operator with scalar in ifm, use ifm2 depth
-                self.ifm_depth = ifm2_tensor.shape[-1]
-            else:
+
+            if ifm_tensor.shape != []:
                 self.ifm_depth = ifm_tensor.shape[-1]
+
+            if self.is_elementwise:
+                self.ifm_count = 2
+                if ifm_tensor.shape == []: # Scalar in ifm1
+                    assert ifm2_tensor
+                    self.ifm_depth = ifm2_tensor.shape[-1]
+                    self.ifm_count = 1
+                elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2
+                    self.ifm_count = 1
+
             if self.ifm_bits == 16:
                 if ps.npu_block_type != NpuBlockType.Pooling and has_scale:
                     self.use_accumulator_element = SHRAMElements.Acc40
@@ -137,7 +146,7 @@
         acc_banks = ofm_config.banks[self.use_accumulator_element]
 
         # Update bank counts for IFM and Accumulator
-        self.banks_required[SharedBufferArea.IFM] = ifm_config.banks[self.use_ifm_element]
+        self.banks_required[SharedBufferArea.IFM] = ifm_config.banks[self.use_ifm_element] * self.ifm_count
         self.banks_required[SharedBufferArea.Accumulators] = 0 if self.is_elementwise else acc_banks
 
         # Validating calculates bank layout and returns validity