Depthwise channel pre-multiplication

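Support non-unit channel multipliers in the unit-multiplier depthfirst
kernels by pre-multiplying the input channels. When a strategy reports
uses_premultiply(), the input tile is expanded into an intermediate
workspace buffer so that each input channel is repeated
channel_multiplier times, and the existing unit-multiplier kernels then
run on the expanded tile. The has_no_channel_multiplier constraint is
dropped from the affected depthfirst kernels, and the fp32 heuristics
gain a prefer_premultiply() check so that the dedicated multiplier
kernels are still chosen for large channel multipliers.

Conceptually the pre-multiply step is a plain NHWC channel replication.
As an illustrative sketch only (premultiply_point is a hypothetical
name, not the optimised routine added in premultiply.cpp):

  // Expand `channels` values at one spatial point into
  // `channels * multiplier` values by repeating each input channel.
  void premultiply_point(const float *in, float *out,
                         unsigned int channels, unsigned int multiplier)
  {
    for (unsigned int c = 0; c < channels; c++)
    {
      for (unsigned int m = 0; m < multiplier; m++)
      {
        out[c * multiplier + m] = in[c];
      }
    }
  }
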
Resolves: COMPMID-6337
Change-Id: Ie9097b3f56e8071426c621386a5988bd7f7e8ef2
Signed-off-by: Michael Tyler <michael.tyler@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9852
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/Android.bp b/Android.bp
index 85efa54..ccf513a 100644
--- a/Android.bp
+++ b/Android.bp
@@ -337,6 +337,7 @@
         "src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8s8u8q.cpp",
         "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp",
         "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.cpp",
+        "src/core/NEON/kernels/arm_conv/depthwise/premultiply.cpp",
         "src/core/NEON/kernels/arm_conv/pooling/kernels/cpp_nhwc_1x1_stride_any_depthfirst/generic.cpp",
         "src/core/NEON/kernels/arm_conv/pooling/pooling_fp16.cpp",
         "src/core/NEON/kernels/arm_conv/pooling/pooling_fp32.cpp",
diff --git a/filelist.json b/filelist.json
index b9196cb..a8f6ffc 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1282,6 +1282,7 @@
               "src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_u8q_3x3_dot.cpp",
               "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp",
               "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.cpp",
+              "src/core/NEON/kernels/arm_conv/depthwise/premultiply.cpp",
               "src/cpu/kernels/depthwiseconv2d/generic/neon/impl.cpp"
               ],
               "fp16":["src/cpu/kernels/depthwiseconv2d/generic/neon/fp16.cpp"],
diff --git a/src/BUILD.bazel b/src/BUILD.bazel
index 85e5650..3701694 100644
--- a/src/BUILD.bazel
+++ b/src/BUILD.bazel
@@ -487,6 +487,7 @@
 	"core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp",
 	"core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst/generic.cpp",
 	"core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst/generic.cpp",
+	"core/NEON/kernels/arm_conv/depthwise/premultiply.cpp",
 	"core/NEON/kernels/arm_conv/pooling/kernels/a64_fp16_nhwc_avg_3x3_s1_output2x2_depthfirst/generic.cpp",
 	"core/NEON/kernels/arm_conv/pooling/kernels/a64_fp16_nhwc_avg_generic_depthfirst/generic.cpp",
 	"core/NEON/kernels/arm_conv/pooling/kernels/a64_fp16_nhwc_max_2x2_s1_output2x2_depthfirst/generic.cpp",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c847a62..085624c 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -479,6 +479,7 @@
 	core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
 	core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst/generic.cpp
 	core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst/generic.cpp
+	core/NEON/kernels/arm_conv/depthwise/premultiply.cpp
 	core/NEON/kernels/arm_conv/pooling/kernels/a64_fp16_nhwc_avg_3x3_s1_output2x2_depthfirst/generic.cpp
 	core/NEON/kernels/arm_conv/pooling/kernels/a64_fp16_nhwc_avg_generic_depthfirst/generic.cpp
 	core/NEON/kernels/arm_conv/pooling/kernels/a64_fp16_nhwc_max_2x2_s1_output2x2_depthfirst/generic.cpp
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
index b6f45c6..592ee72 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
@@ -72,10 +72,10 @@
   std::unique_ptr<const IDepthfirstStrategy> m_strat;
 
   /* Compute the amount of working space required for a single thread. */
-  virtual size_t get_working_size_per_thread(unsigned int n_input_channels) const = 0;
+  virtual size_t get_working_size_per_thread() const = 0;
 
   /* Initialise the working space for a thread. */
-  virtual void initialise_working_space(void *, unsigned int n_input_channels) const = 0;
+  virtual void initialise_working_space(void *) const = 0;
 
   /* Compute a portion of the output tensor with padding. */
   virtual void compute_tile_padded(
@@ -164,8 +164,8 @@
   {
     // Get and initialise the working space for this thread.
     void *thread_working_space =
-      static_cast<uint8_t *>(working_space) + thread_id * this->get_working_size_per_thread(args.input_channels);
-    this->initialise_working_space(thread_working_space, args.input_channels);
+      static_cast<uint8_t *>(working_space) + thread_id * this->get_working_size_per_thread();
+    this->initialise_working_space(thread_working_space);
 
     // Construct convenient representations of the input/output tensors.
     TensorSpec<const TInput *> input_tensor(reinterpret_cast<const TInput *>(input), ld_input_row, ld_input_col);
@@ -189,7 +189,9 @@
         const bool pad_input_top = start_input_i < 0;
         const int end_input_i = start_input_i + m_strat->get_input_rows();
         const bool pad_input_bottom = static_cast<int>(args.input_rows) < end_input_i;
-        const bool pad_row = pad_input_top || pad_input_bottom || pad_output_bottom;
+        // We only need to account for input padding if direct padding is not supported.
+        const bool pad_row = ((pad_input_top || pad_input_bottom) && !this->supports_direct_padding())
+                || pad_output_bottom;
 
         // Iterate over the columns of the output tensor; we attempt to grab as
         // much as possible of the unpadded regions, so the loop structure is a
@@ -202,7 +204,7 @@
 
           // Determine if we can process a number of unpadded tiles in one go.
           int n_unpadded_tiles = 0;
-          if (!pad_input_left)
+          if ((!pad_input_left) || this->supports_direct_padding())
           {
             // Determine the maximum number of tiles we could handle.
             n_unpadded_tiles = (args.output_cols - start_output_j) / m_strat->get_output_cols();
@@ -273,9 +275,14 @@
   {
   }
 
-  size_t get_working_size(unsigned int n_threads, unsigned int n_input_channels) const override final
+  size_t get_working_size(unsigned int n_threads) const override final
   {
-    return n_threads * this->get_working_size_per_thread(n_input_channels);
+    return n_threads * this->get_working_size_per_thread();
+  }
+
+  virtual bool supports_direct_padding() const
+  {
+    return false;
   }
 };
 
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
index 2620b48..7b00c9a 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
@@ -115,7 +115,7 @@
   {
     return interleaves::PackingArguments(
       this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
-      false, sizeof(int32_t),  // Don't pack the bias
+      false, sizeof(int32_t), this->uses_premultiply(),  // Don't pack the bias
       this->get_vl_type(), sizeof(int32_t), this->get_accumulator_depth_vl(),
       [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
       { return this->get_kernel_packing_point(idx, x, y); }
@@ -162,6 +162,64 @@
   inline OutputStage &get_output_stage(void) { return m_os; }
   inline const OutputStage &get_output_stage(void) const { return m_os; }
 
+  bool uses_intermediate_array() const
+  {
+    return this->m_args.channel_multiplier != 1 && this->uses_premultiply();
+  }
+
+  virtual void fill_inptr_array(const DepthwiseArgs &args,
+    const TensorSpec<const TInput *> &input,
+    const TInput **inptr_array, TInput *input_buffer,
+    const unsigned int input_i, const unsigned int input_j,
+    const unsigned int input_pad_top, const unsigned int input_pad_left) const = 0;
+
+  void initialise_inptr_array(const DepthwiseArgs &args,
+      unsigned int output_channel_start, unsigned int output_channel_end,
+      const TensorSpec<const TInput *> &input,
+      const TInput **inptr_array, TInput *input_buffer, TInput *intermediate_buffer,
+      const unsigned int input_i, const unsigned int input_j,
+      const unsigned int input_pad_top, const unsigned int input_pad_left,
+      Tile<TInput> &multiplied_input
+  ) const
+  {
+    // Compute the input pointer array
+    const auto input_channel_start = output_channel_start / args.channel_multiplier;
+
+    const auto last_valid_row = std::min(input_pad_top + args.input_rows - input_i, this->m_strat->get_input_rows());
+    const auto last_valid_col = std::min(input_pad_left + args.input_cols - input_j, this->m_strat->get_input_cols());
+
+    const auto tile_rows = last_valid_row - input_pad_top;
+    const auto tile_cols = last_valid_col - input_pad_left;
+
+    const auto tile_channels = output_channel_end - output_channel_start;
+
+    TensorSpec<const TInput *> tile_tensor(0, 0, 0);
+    if (this->uses_intermediate_array()) {
+      multiplied_input = Tile<TInput>(intermediate_buffer, tile_rows, tile_cols, tile_channels);
+      multiplied_input.load_from(input.base, input.ld_row, input.ld_col,
+                                 args.input_rows, args.input_cols,
+                                 input_i, input_j, args.channel_multiplier);
+
+      tile_tensor = TensorSpec<const TInput *>(
+        multiplied_input.array,
+        tile_cols * tile_channels, tile_channels
+      );
+    } else {
+      tile_tensor = TensorSpec<const TInput *>(
+        input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
+        input.ld_row, input.ld_col
+      );
+    }
+
+    fill_inptr_array(args,
+      tile_tensor,
+      inptr_array, input_buffer,
+      input_i, input_j,
+      input_pad_top,
+      input_pad_left
+    );
+  }
+
   public:
   DepthwiseDepthfirstCommon(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os)
   : DepthfirstDriver<TInput, TWeight, TOutput>(strat, args), m_os(os)
@@ -321,6 +379,7 @@
     OutputArrayElement<TOutput>,
     depthwise_depthfirst::InputArrayElement<TInput>,
     InputBufferElement<TInput>,
+    IntermediateBufferElement<TInput>,
     typename depthwise_depthfirst::WorkspaceFinalElement<TAccum, OutputStage>::Element
   >;
   using WorkingSpace = typename WorkspaceManager::WorkspaceType;
@@ -347,25 +406,46 @@
     depthwise_depthfirst::stash_bias(this->get_output_stage(), biases);
   }
 
-  size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
+  size_t get_working_size_per_thread() const override
   {
     DepthwiseArgs args(this->m_args);
-    args.input_channels = n_input_channels;
     return WorkspaceManager::get_sizeof_workspace(
       WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())
     );
   }
 
-  void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
+  void initialise_working_space(void *buffer) const override
   {
     DepthwiseArgs args(this->m_args);
-    args.input_channels = n_input_channels;
     WorkspaceManager::initialise(
       buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())
     );
   }
 
+  virtual bool supports_direct_padding() const override
+  {
+    using Invoker = depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>;
+    return Invoker::supports_direct_kernel && this->uses_intermediate_array();
+  }
+
   protected:
+
+  void fill_inptr_array(const DepthwiseArgs &args,
+    const TensorSpec<const TInput *> &input,
+    const TInput **inptr_array, TInput *input_buffer,
+    const unsigned int input_i, const unsigned int input_j,
+    const unsigned int input_pad_top, const unsigned int input_pad_left) const override
+  {
+    fill_pointer_array<const TInput>(
+      inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
+      input.base,
+      input.ld_row, input.ld_col,
+      input_buffer,
+      input_pad_top, args.input_rows - input_i,
+      input_pad_left, args.input_cols - input_j
+    );
+  }
+
   void compute_tile_padded(
     const DepthwiseArgs &args,
     unsigned int output_i, unsigned int output_j,
@@ -380,8 +460,6 @@
     auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
 
     // Compute the input pointer array
-    const auto input_channel_start = output_channel_start / args.channel_multiplier;
-
     const int ii = static_cast<int>(output_i * args.stride_rows) - args.padding.top;
     const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
     const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
@@ -390,14 +468,10 @@
     const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);
     const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);
 
-    fill_pointer_array<const TInput>(
-      ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
-      input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
-      input.ld_row, input.ld_col,
-      ws->input_buffer,
-      input_pad_top, args.input_rows - input_i,
-      input_pad_left, args.input_cols - input_j
-    );
+    Tile<TInput> multiplied_input;
+    this->initialise_inptr_array(args, output_channel_start, output_channel_end, input,
+      ws->inptr_array, ws->input_buffer, ws->intermediate_buffer,
+      input_i, input_j, input_pad_top, input_pad_left, multiplied_input);
 
     // Compute the output pointer array
     fill_pointer_array(
@@ -432,12 +506,11 @@
     const auto os = this->get_output_stage();
 
     // Compute top and bottom padding; hence fill in the initial pointer arrays.
-    const auto input_channel_start = output_channel_start / args.channel_multiplier;
     const int ii = static_cast<int>(output_i * args.stride_rows) - args.padding.top;
     const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
 
     const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
-    const auto input_j = output_j * args.stride_cols - args.padding.left;
+    auto input_j = output_j * args.stride_cols - args.padding.left;
 
     // Valid input rows is the smallest of the input rows that aren't padding for this tile, and the number of rows
     // available.
@@ -447,14 +520,10 @@
     const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * args.stride_cols;
     const auto output_point_stride = output.ld_col * this->m_strat->get_output_cols();
 
-    fill_pointer_array<const TInput>(
-      ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
-      input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
-      input.ld_row, input.ld_col,
-      ws->input_buffer,
-      input_pad_top, args.input_rows - input_i,
-      0, args.input_cols - input_j  // No left padding
-    );
+    Tile<TInput> multiplied_input;
+    this->initialise_inptr_array(args, output_channel_start, output_channel_end, input,
+      ws->inptr_array, ws->input_buffer, ws->intermediate_buffer,
+      input_i, input_j, input_pad_top, 0, multiplied_input);
 
     fill_pointer_array(
       ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
@@ -473,16 +542,25 @@
       );
 
       // Update all unpadded pointers
-      {
-        auto ptr = ws->inptr_array + strat->get_input_cols() * input_pad_top;
-        for (auto n = input_pad_top; n < (valid_input_rows + input_pad_top); n++)
+      if (this->uses_intermediate_array()) {
+        input_j += input_point_stride / input.ld_col;
+        multiplied_input.load_from(input.base,
+          input.ld_row, input.ld_col,
+          args.input_rows, args.input_cols,
+          input_i, input_j, args.channel_multiplier);
+      } else {
         {
-          for (auto m = 0u; m < strat->get_input_cols(); m++)
+          auto ptr = ws->inptr_array + strat->get_input_cols() * input_pad_top;
+          for (auto n = input_pad_top; n < (valid_input_rows + input_pad_top); n++)
           {
-            *(ptr++) += input_point_stride;
+            for (auto m = 0u; m < strat->get_input_cols(); m++)
+            {
+              *(ptr++) += input_point_stride;
+            }
           }
         }
       }
+
       {
         auto ptr = ws->outptr_array;
         for (auto n = 0u; n < valid_output_rows * strat->get_output_cols(); n++)
@@ -511,6 +589,13 @@
 
     if (Invoker::supports_direct_kernel)
     {
+      PaddingValues tile_padding = {
+              args.kernel_cols / 2,
+              args.kernel_rows / 2,
+              args.kernel_cols / 2,
+              args.kernel_rows / 2
+      };
+
       // If the direct kernel is supported, then use it.
       // Compute the base pointers we'll use in the tile.
       auto outptr = output.base + output_channel_start + output_i * output.ld_row + output_j * output.ld_col;
@@ -518,11 +603,31 @@
       const int start_input_j = output_j * args.stride_cols - args.padding.left;
       auto inptr = input.base + output_channel_start + start_input_i * input.ld_row + start_input_j * input.ld_col;
 
+      auto ld_row = input.ld_row;
+      auto ld_col = input.ld_col;
+
+      const auto tile_rows = this->m_strat->get_output_rows() * args.stride_rows * n_tile_rows + tile_padding.top + tile_padding.bottom;
+      const auto tile_cols = this->m_strat->get_output_cols() * args.stride_cols * n_tile_cols + tile_padding.left + tile_padding.right;
+      const auto tile_channels = output_channel_end - output_channel_start;
+
+      Tile<TInput> multiplied_input;
+      if (this->uses_intermediate_array()) {
+        multiplied_input = Tile<TInput>(ws->intermediate_buffer, tile_rows, tile_cols, tile_channels);
+        multiplied_input.load_from(input.base,
+          input.ld_row, input.ld_col,
+          args.input_rows, args.input_cols,
+          start_input_i, start_input_j, args.channel_multiplier);
+
+        ld_row = tile_cols * tile_channels;
+        ld_col = tile_channels;
+        inptr = multiplied_input.array;
+      }
+
       // Execute the kernel
       Invoker::direct(
         strat, ws, os,
         n_tile_rows, n_tile_cols,
-        inptr, input.ld_row, input.ld_col,
+        inptr, ld_row, ld_col,
         outptr, output.ld_row, output.ld_col,
         parameters, output_channel_end - output_channel_start
       );
@@ -531,7 +636,6 @@
     {
       // Otherwise, we repeatedly call the padded kernel but use our knowledge
       // of the tensor structure to avoid recomputing the pointer array.
-      const auto input_channel_start = output_channel_start / args.channel_multiplier;
 
       const auto n_input_pointers = this->m_strat->get_input_rows() * this->m_strat->get_input_cols();
       const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * args.stride_cols;
@@ -543,16 +647,12 @@
       for (unsigned int tile_i = 0; tile_i < n_tile_rows; tile_i++)
       {
         const int input_i = static_cast<int>(output_i * args.stride_rows) - args.padding.top;
-        const int input_j = static_cast<int>(output_j * args.stride_cols) - args.padding.left;
+        int input_j = static_cast<int>(output_j * args.stride_cols) - args.padding.left;
 
-        fill_pointer_array<const TInput>(
-          ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
-          input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
-          input.ld_row, input.ld_col,
-          ws->input_buffer,
-          0, args.input_rows,
-          0, args.input_cols
-        );
+        Tile<TInput> multiplied_input;
+        this->initialise_inptr_array(args, output_channel_start, output_channel_end, input,
+          ws->inptr_array, ws->input_buffer, ws->intermediate_buffer,
+          input_i, input_j, 0, 0, multiplied_input);
 
         // Compute the output pointer array
         fill_pointer_array(
@@ -572,10 +672,18 @@
           );
 
           // Progress the pointers
-          for (auto i = 0u; i < n_input_pointers; i++)
-          {
-            ws->inptr_array[i] += input_point_stride;
+          if (this->uses_intermediate_array()) {
+            input_j += input_point_stride / input.ld_col;
+            multiplied_input.load_from(input.base,
+              input.ld_row, input.ld_col,
+              args.input_rows, args.input_cols, input_i, input_j, args.channel_multiplier);
+          } else {
+            for (auto i = 0u; i < n_input_pointers; i++)
+            {
+              ws->inptr_array[i] += input_point_stride;
+            }
           }
+
           for (auto i = 0u; i < n_output_pointers; i++)
           {
             ws->outptr_array[i] += output_point_stride;
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
index b058ce2..ca5026b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
@@ -99,7 +99,7 @@
   {
     interleaves::PackingArguments packing_args(
       this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
-      false, sizeof(TAccum),  // Don't pack the bias
+      false, sizeof(TAccum), this->uses_premultiply(),  // Don't pack the bias
       this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
       [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
       { return this->get_kernel_packing_point(idx, x, y); }
@@ -115,7 +115,7 @@
   {
     interleaves::PackingArguments packing_args(
       this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
-      false, sizeof(TAccum),  // Don't pack the bias
+      false, sizeof(TAccum), this->uses_premultiply(),  // Don't pack the bias
       this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
       [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
       { return this->get_kernel_packing_point(idx, x, y); }
@@ -208,6 +208,7 @@
     OutputArrayElement<TOutput>,
     GenericInputArrayElement<TInput>,
     InputBufferElement<TInput>,
+    IntermediateBufferElement<TInput>,
     ActivationsElement<TAccum, OutputStage>
   >;
   using WorkingSpace = typename WorkspaceManager::WorkspaceType;
@@ -232,21 +233,38 @@
     depthwise_depthfirst::stash_bias(this->get_output_stage(), m_bias);
   }
 
-  size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
+  size_t get_working_size_per_thread() const override
   {
     DepthwiseArgs args(this->m_args);
-    args.input_channels = n_input_channels;
     return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage()));
   }
 
-  void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
+  void initialise_working_space(void *buffer) const override
   {
     DepthwiseArgs args(this->m_args);
-    args.input_channels = n_input_channels;
     return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage()));
   }
 
   protected:
+  void fill_inptr_array(const DepthwiseArgs &args,
+    const TensorSpec<const TInput *> &input,
+    const TInput **inptr_array, TInput *input_buffer,
+    const unsigned int input_i, const unsigned int input_j,
+    const unsigned int input_pad_top, const unsigned int input_pad_left) const override
+  {
+    fill_pointer_array_generic_kernel<const TInput>(
+      inptr_array,
+      this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+      args.kernel_rows, args.kernel_cols,
+      args.stride_rows, args.stride_cols,
+      input.base,
+      input.ld_row, input.ld_col,
+      input_buffer,
+      input_pad_top, args.input_rows - input_i,
+      input_pad_left, args.input_cols - input_j
+    );
+  }
+
   void compute_tile_padded(
     const DepthwiseArgs &args,
     unsigned int output_i, unsigned int output_j,
@@ -268,17 +286,10 @@
     const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);
     const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);
 
-    fill_pointer_array_generic_kernel<const TInput>(
-      ws->inptr_array,
-      this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
-      args.kernel_rows, args.kernel_cols,
-      args.stride_rows, args.stride_cols,
-      input.base + input_i*input.ld_row + input_j*input.ld_col + channel_start,
-      input.ld_row, input.ld_col,
-      ws->input_buffer,
-      input_pad_top, args.input_rows - input_i,
-      input_pad_left, args.input_cols - input_j
-    );
+    Tile<TInput> multiplied_input;
+    this->initialise_inptr_array(args, channel_start, channel_end, input,
+      ws->inptr_array, ws->input_buffer, ws->intermediate_buffer,
+      input_i, input_j, input_pad_top, input_pad_left, multiplied_input);
 
     // Compute the output pointer array
     fill_pointer_array<TOutput>(
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
index 3d305b6..b93caa2 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
@@ -42,7 +42,7 @@
   {
     return interleaves::PackingArguments(
       args.kernel_rows, args.kernel_cols, sizeof(TWeight),
-      true, sizeof(TAccum),
+      true, sizeof(TAccum), this->uses_premultiply(),
       this->get_vl_type(),
       sizeof(TAccum), 1,
       [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
@@ -57,6 +57,10 @@
       }
     );
   }
+
+  bool uses_premultiply() const override {
+    return false;
+  }
 
   public:
   using Parent::Parent;
@@ -192,7 +196,7 @@
   {
     return interleaves::PackingArguments(
       args.kernel_rows, args.kernel_cols, sizeof(TWeight),
-      false, sizeof(TAccum),
+      false, sizeof(TAccum), this->uses_premultiply(),
       this->get_vl_type(),
       sizeof(TAccum), 1,
       [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
@@ -207,6 +211,10 @@
       }
     );
   }
+
+  bool uses_premultiply() const override {
+    return false;
+  }
 
   public:
   GenericDepthfirstMultiplierStrategy(KernelStrategyType *kern, const DepthwiseArgs &args)
@@ -483,6 +491,10 @@
   OutputStage m_os;  // Copy of the output parameters
   const void *m_bias = nullptr;  // Copy of the bias (should we need it)
 
+  bool uses_premultiply() const override {
+    return false;
+  }
+
   public:
   DepthwiseDepthfirstMultiplier(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os = {})
   : DepthfirstDriver<TInput, TWeight, TOutput>(strat, args), m_os(os)
@@ -506,17 +518,15 @@
     depthwise_depthfirst::stash_bias(m_os, biases);
   }
 
-  size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
+  size_t get_working_size_per_thread() const override
   {
     DepthwiseArgs args(this->m_args);
-    args.input_channels = n_input_channels;
     return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
   }
 
-  void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
+  void initialise_working_space(void *buffer) const override
   {
     DepthwiseArgs args(this->m_args);
-    args.input_channels = n_input_channels;
     return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
   }
 
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp
index ed4f17d..3b76e52 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp
@@ -28,6 +28,7 @@
 #include "depthwise_depthfirst.hpp"
 #include "depthwise_depthfirst_generic.hpp"
 #include "depthwise_depthfirst_multiplier.hpp"
+#include "depthwise_planar.hpp"
 
 #include "depthwise_implementation_constraints.hpp"
 
@@ -35,14 +36,14 @@
 #if defined(__ARM_FP16_ARGS)
 
 #if defined(__aarch64__)
-#if defined(ARM_COMPUTE_ENABLE_SVE)
 #if defined(ARM_COMPUTE_ENABLE_SME2)
 #include "kernels/sme2_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp"
 #include "kernels/sme2_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp"
 #include "kernels/sme2_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/sme2_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp"
 #include "kernels/sme2_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
-#endif // defined(ARM_COMPUTE_ENABLE_SME2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SME2)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
 #include "kernels/sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp"
 #include "kernels/sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp"
 #include "kernels/sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
@@ -163,12 +164,11 @@
       return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
-#endif // defined(ARM_COMPUTE_ENABLE_SME2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SME2)
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst",
     constraint(is_supported<sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -180,7 +180,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst",
     constraint(is_supported<sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -192,7 +191,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst",
     constraint(is_supported<sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst>,
-              has_no_channel_multiplier,
               cpu_has_sve),
     cycle_estimate<sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -204,7 +202,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst",
     constraint(is_supported<sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -216,7 +213,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst",
     constraint(is_supported<sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -229,7 +225,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst",
     constraint(is_supported<a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_fp16),
     cycle_estimate<a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -241,7 +236,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst",
     constraint(is_supported<a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_fp16),
     cycle_estimate<a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -253,7 +247,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst",
     constraint(is_supported<a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_fp16),
     cycle_estimate<a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -265,7 +258,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst",
     constraint(is_supported<a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_fp16),
     cycle_estimate<a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -277,7 +269,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst",
     constraint(is_supported<a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_fp16),
     cycle_estimate<a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
@@ -288,7 +279,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_generic_output3x3_mla_depthfirst",
-    constraint(has_no_channel_multiplier, cpu_has_fp16),
+    constraint(cpu_has_fp16),
     not_preferred,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto kern = new a64_fp16_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp
index 382ccd3..9954be1 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp
@@ -79,9 +79,45 @@
 
 namespace
 {
+  bool prefer_premultiply(const DepthwiseArgs &args) {
+    if ((args.stride_rows != args.stride_cols) || (args.kernel_rows != args.kernel_cols))
+    {
+      return false;
+    }
+
+    unsigned int threshold;
+
+    if (args.stride_rows == 1 && args.kernel_rows == 3)
+    {
+      threshold = 18;
+    }
+    else if (args.stride_rows == 1 && args.kernel_rows == 5)
+    {
+      threshold = 5;
+    }
+    else if (args.stride_rows == 2 && args.kernel_rows == 3)
+    {
+      threshold = 5;
+    }
+    else if (args.stride_rows == 2 && args.kernel_rows == 5)
+    {
+      threshold = 12;
+    } else
+    {
+      return false;
+    }
+
+    return args.channel_multiplier <= threshold;
+  }
+
   template <class Strategy>
   unsigned int cycle_estimate(const DepthwiseArgs &args, const Nothing &)
   {
+    if (args.channel_multiplier > 1 && !prefer_premultiply(args))
+    {
+      return UINT32_MAX;
+    }
+
     // First-pass: compute the number of output pixels which will be computed.
     return arm_gemm::roundup(args.output_rows, Strategy::output_rows) *
            arm_gemm::roundup(args.output_cols, Strategy::output_cols) *
@@ -116,6 +152,11 @@
   }
 
 #if defined(__aarch64__)
+  unsigned int multiplier_cycle_estimate(const DepthwiseArgs &args, const Nothing &)
+  {
+    return prefer_premultiply(args) ? UINT32_MAX : 0;
+  }
+
   unsigned int not_preferred(const DepthwiseArgs &, const Nothing &)
   {
     return std::numeric_limits<unsigned int>::max();
@@ -246,8 +287,7 @@
     DepthwiseMethod::DEPTHFIRST,
     "sme2_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst",
     constraint(cpu_has_sme,  cpu_has_sme2,
-               is_supported<sme2_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
-               has_no_channel_multiplier),
+               is_supported<sme2_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>),
     cycle_estimate<sme2_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sme2_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(args.cpu_info);
@@ -258,8 +298,7 @@
     DepthwiseMethod::DEPTHFIRST,
     "sme2_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst",
     constraint(cpu_has_sme, cpu_has_sme2,
-               is_supported<sme2_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
-               has_no_channel_multiplier),
+               is_supported<sme2_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>),
     cycle_estimate<sme2_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sme2_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(args.cpu_info);
@@ -270,8 +309,7 @@
     DepthwiseMethod::DEPTHFIRST,
     "sme2_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst",
     constraint(cpu_has_sme, cpu_has_sme2,
-               is_supported<sme2_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier),
+               is_supported<sme2_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>),
     cycle_estimate<sme2_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sme2_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
@@ -282,8 +320,7 @@
     DepthwiseMethod::DEPTHFIRST,
     "sme2_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst",
     constraint(cpu_has_sme, cpu_has_sme2,
-               is_supported<sme2_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier),
+               is_supported<sme2_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>),
     cycle_estimate<sme2_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sme2_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
@@ -295,7 +332,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst",
     constraint(is_supported<sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
@@ -307,7 +343,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst",
     constraint(is_supported<sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
@@ -319,7 +354,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst",
     constraint(is_supported<sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
-              has_no_channel_multiplier,
               cpu_has_sve),
     cycle_estimate<sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
@@ -331,7 +365,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst",
     constraint(is_supported<sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
@@ -343,7 +376,6 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst",
     constraint(is_supported<sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier,
                cpu_has_sve),
     cycle_estimate<sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
@@ -354,7 +386,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_generic_output3x3_mla_depthfirst",
-    constraint(has_no_channel_multiplier, cpu_has_sve),
+    constraint(cpu_has_sve),
     not_preferred,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto kern = new sve_fp32_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
@@ -367,7 +399,7 @@
     "sve_fp32_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst",
     constraint(is_supported<sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst>,
                cpu_has_sve, has_channel_multiplier),
-    nullptr,
+    multiplier_cycle_estimate,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(args.cpu_info);
       return new DepthwiseDepthfirstMultiplier<float>(strat, args);
@@ -378,7 +410,7 @@
     "sve_fp32_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst",
     constraint(is_supported<sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst>,
                cpu_has_sve, has_channel_multiplier),
-    nullptr,
+    multiplier_cycle_estimate,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(args.cpu_info);
       return new DepthwiseDepthfirstMultiplier<float>(strat, args);
@@ -388,7 +420,7 @@
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
     constraint(cpu_has_sve, has_channel_multiplier),
-    nullptr,
+    multiplier_cycle_estimate,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto kern = new sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
       auto strat = new GenericDepthfirstMultiplierStrategy<float>(kern, args);
@@ -399,8 +431,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst",
-    constraint(is_supported<a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
-               has_no_channel_multiplier),
+    constraint(is_supported<a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>),
     cycle_estimate<a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(args.cpu_info);
@@ -410,8 +441,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst",
-    constraint(is_supported<a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
-               has_no_channel_multiplier),
+    constraint(is_supported<a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>),
     cycle_estimate<a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(args.cpu_info);
@@ -421,8 +451,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst",
-    constraint(is_supported<a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
-                            has_no_channel_multiplier),
+    constraint(is_supported<a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>),
     cycle_estimate<a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
@@ -432,8 +461,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst",
-    constraint(is_supported<a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier),
+    constraint(is_supported<a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>),
     cycle_estimate<a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
@@ -443,8 +471,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst",
-    constraint(is_supported<a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst>,
-               has_no_channel_multiplier),
+    constraint(is_supported<a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst>),
     cycle_estimate<a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
@@ -454,7 +481,7 @@
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_generic_output3x3_mla_depthfirst",
-    constraint(has_no_channel_multiplier),
+    nullptr,
     not_preferred,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto kern = new a64_fp32_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
@@ -467,7 +494,7 @@
     "a64_fp32_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst",
     constraint(is_supported<a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst>,
                has_channel_multiplier),
-    nullptr,
+    multiplier_cycle_estimate,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(args.cpu_info);
       return new DepthwiseDepthfirstMultiplier<float>(strat, args);
@@ -478,7 +505,7 @@
     "a64_fp32_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst",
     constraint(is_supported<a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst>,
                has_channel_multiplier),
-    nullptr,
+    multiplier_cycle_estimate,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(args.cpu_info);
       return new DepthwiseDepthfirstMultiplier<float>(strat, args);
@@ -488,7 +515,7 @@
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
     constraint(has_channel_multiplier),
-    nullptr,
+    multiplier_cycle_estimate,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto kern = new a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
       auto strat = new GenericDepthfirstMultiplierStrategy<float>(kern, args);
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_planar.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_planar.hpp
index 567eab1..c3daaf0 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_planar.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_planar.hpp
@@ -153,7 +153,7 @@
   {
     return interleaves::PackingArguments(
       m_kernel_rows, m_kernel_cols, sizeof(TWeight),
-      false, sizeof(TAccum),  // Don't pack the bias
+      false, sizeof(TAccum), true,  // Don't pack the bias
       m_vl_type, sizeof(TAccum), 1,  // Accumulator depth of 1 TODO
       [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
       { return this->get_kernel_packing_point(idx, x, y); }
@@ -276,7 +276,7 @@
     depthwise_depthfirst::stash_bias(this->m_os, biases);
   }
 
-  size_t get_working_size(unsigned int n_threads, unsigned int) const override
+  size_t get_working_size(unsigned int n_threads) const override
   {
     return this->get_working_size_per_thread() * n_threads;
   }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.cpp
index 33f2177..37892b6 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022 Arm Limited.
+ * Copyright (c) 2022-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -41,6 +41,8 @@
 unsigned int DepthfirstStrategyUntyped::get_n_output_points() const { return this->get_output_rows() * this->get_output_cols(); }
 unsigned int DepthfirstStrategyUntyped::get_n_kernel_points() const { return this->get_kernel_rows() * this->get_kernel_cols(); }
 
+bool DepthfirstStrategyUntyped::uses_premultiply() const { return true; }
+
 unsigned int DepthfirstStrategyUntyped::get_accumulator_depth_vl() const { return 1; }
 
 bool DepthfirstStrategyUntyped::get_kernel_packing_point(const unsigned int index, unsigned int &x, unsigned int &y) const
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.hpp
index 39f60c3..19cf26d 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.hpp
@@ -49,6 +49,8 @@
   virtual unsigned int get_n_output_points() const;
   virtual unsigned int get_n_kernel_points() const;
 
+  virtual bool uses_premultiply() const;
+
   // Get the number of VLs used in the accumulator, this defaults to 1.
   virtual unsigned int get_accumulator_depth_vl() const;
 
@@ -65,7 +67,7 @@
   {
     interleaves::PackingArguments packing_args(
       this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
-      true, sizeof(TAccum),
+      true, sizeof(TAccum), this->uses_premultiply(),
       this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
       [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
       { return this->get_kernel_packing_point(idx, x, y); }
@@ -81,7 +83,7 @@
   {
     interleaves::PackingArguments packing_args(
       this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
-      true, sizeof(TAccum),
+      true, sizeof(TAccum), this->uses_premultiply(),
       this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
       [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
       { return this->get_kernel_packing_point(idx, x, y); }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_s8q_3x3_dot.cpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_s8q_3x3_dot.cpp
index 5e4bf99..3de4bdc 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_s8q_3x3_dot.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_s8q_3x3_dot.cpp
@@ -42,7 +42,7 @@
 {
   // We store 7 vectors for every <vector_of_ints> of channels.
   const unsigned int n = arm_gemm::roundup(
-    arm_gemm::iceildiv((long unsigned int) args.input_channels,
+    arm_gemm::iceildiv((long unsigned int) args.input_channels * args.channel_multiplier,
                        get_vector_length<int32_t>(arm_gemm::VLType::None)), 4lu
   );
   return n * 7 * get_vector_length<int8_t>(arm_gemm::VLType::None);
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
index 056f08d..dc505a0 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022 Arm Limited.
+ * Copyright (c) 2022-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -32,11 +32,11 @@
 
 PackingArguments::PackingArguments(
   unsigned int kernel_rows, unsigned int kernel_cols, size_t weight_element_size,
-  bool include_bias, size_t bias_element_size,
+  bool include_bias, size_t bias_element_size, bool premultiply,
   arm_gemm::VLType vl_type, size_t accumulator_element_size, unsigned int accumulator_depth_vl,
   std::function<bool(unsigned int, unsigned int &, unsigned int &)> get_weight_pos
 ) : kernel_rows(kernel_rows), kernel_cols(kernel_cols), weight_element_size(weight_element_size),
-    include_bias(include_bias), bias_element_size(bias_element_size),
+    include_bias(include_bias), bias_element_size(bias_element_size), premultiply(premultiply),
     vl_type(vl_type), accumulator_element_size(accumulator_element_size), accumulator_depth_vl(accumulator_depth_vl),
     get_weight_pos(get_weight_pos)
 {
@@ -46,7 +46,7 @@
 {
   // If the channel multiplier is greater than one, then we treat this as a
   // repeated packing of `channel_multiplier`-sized problems.
-  if (args.channel_multiplier > 1)
+  if (args.channel_multiplier > 1 && !packing_args.premultiply)
   {
     DepthwiseArgs args_per_input_channel(args);
     args_per_input_channel.input_channels = args.channel_multiplier;
@@ -58,7 +58,7 @@
   const unsigned int vl =
     packing_args.accumulator_depth_vl *
     arm_gemm::utils::get_vector_length<uint8_t>(packing_args.vl_type) / packing_args.accumulator_element_size;
-  const unsigned int n_packs = arm_gemm::iceildiv(args.input_channels, vl);
+  const unsigned int n_packs = arm_gemm::iceildiv(args.input_channels * args.channel_multiplier, vl);
   const auto pack_size = (packing_args.include_bias ? packing_args.bias_element_size : 0) +
                          packing_args.kernel_points() * packing_args.weight_element_size;
   return n_packs * pack_size * vl;
@@ -81,7 +81,7 @@
 
   // If the channel multiplier is greater than one, then we treat this as a
   // repeated packing of `channel_multiplier`-sized problems.
-  if (args.channel_multiplier > 1)
+  if (args.channel_multiplier > 1 && !packing_args.premultiply)
   {
     // Get a modified copy of the depthwise arguments
     DepthwiseArgs args_per_input_channel(args);
@@ -107,17 +107,19 @@
     return;
   }
 
+  auto input_channels = args.input_channels * args.channel_multiplier;
+
   // Finalise the weight strides
-  ld_weight_col = (ld_weight_col == 0) ? args.input_channels : ld_weight_col;
+  ld_weight_col = (ld_weight_col == 0) ? input_channels : ld_weight_col;
   ld_weight_row = (ld_weight_row == 0) ? packing_args.kernel_cols * ld_weight_col : ld_weight_row;
 
   const unsigned int vl =
     packing_args.accumulator_depth_vl *
     arm_gemm::utils::get_vector_length<uint8_t>(packing_args.vl_type) / packing_args.accumulator_element_size;
 
-  for (unsigned int n = 0; n < args.input_channels; n += vl)
+  for (unsigned int n = 0; n < input_channels; n += vl)
   {
-    const unsigned int todo = std::min(vl, args.input_channels - n);
+    const unsigned int todo = std::min(vl, input_channels - n);
 
     if (packing_args.include_bias)
     {
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.hpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.hpp
index 756c50b..1842f10 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.hpp
@@ -40,6 +40,7 @@
   const size_t weight_element_size;
   const bool include_bias;
   const size_t bias_element_size;
+  const bool premultiply;
   arm_gemm::VLType vl_type;
   const size_t accumulator_element_size;
   const unsigned int accumulator_depth_vl;
@@ -53,6 +54,7 @@
     size_t weight_element_size,
     bool include_bias,
     size_t bias_element_size,
+    bool premultiply,
     arm_gemm::VLType vl_type,
     size_t accumulator_element_size,
     unsigned int accumulator_depth_vl,
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
index 85053b3..2b97ad8 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
@@ -64,7 +64,7 @@
   ) const override
   {
     interleave_a64_s8q_3x3_dot::pack_parameters(
-      args.input_channels, buffer, reinterpret_cast<const int32_t *>(biases),
+      args.input_channels * args.channel_multiplier, buffer, reinterpret_cast<const int32_t *>(biases),
       reinterpret_cast<const int8_t *>(weights), qp, ld_weight_col, ld_weight_row
     );
   }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/premultiply.cpp b/src/core/NEON/kernels/arm_conv/depthwise/premultiply.cpp
new file mode 100644
index 0000000..ad4c821
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/premultiply.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <premultiply.hpp>
+
+#define CHANNEL_MULTIPLIER 6
+#define BLOCK_SIZE 4
+
+void do_premultiply_float_6(const float       *in_ptr,
+                            const unsigned int ld_row,
+                            const unsigned int ld_col,
+                            float             *out_ptr,
+                            const unsigned int out_ld_row,
+                            const unsigned int out_ld_col,
+                            const unsigned int tile_rows,
+                            const unsigned int tile_cols,
+                            const unsigned     input_channels)
+{
+    for(unsigned int i = 0; i < tile_rows; i++)
+    {
+        const float *ip2 = in_ptr + i * ld_row;
+        float       *op2 = out_ptr + i * out_ld_row;
+        for(unsigned int j = 0; j < tile_cols; j++)
+        {
+            const float *ip = ip2;
+            float       *op = op2;
+            for(unsigned int c = 0; c < input_channels; c += BLOCK_SIZE)
+            {
+                float vals[BLOCK_SIZE];
+                for(unsigned int v = 0; v < BLOCK_SIZE; v++)
+                {
+                    vals[v] = ip[v];
+                }
+                ip += BLOCK_SIZE;
+
+                for(unsigned int v = 0; v < BLOCK_SIZE; v++)
+                {
+                    for(unsigned int r = 0; r < CHANNEL_MULTIPLIER; r++)
+                    {
+                        op[r] = vals[v];
+                    }
+                    op += CHANNEL_MULTIPLIER;
+                }
+            }
+            ip2 += ld_col;
+            op2 += out_ld_col;
+        }
+    }
+}
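
[Editor's note] For readers who want the effect of this specialised routine in isolation, here is a tiny standalone sketch of the same channel replication; the function name and sizes are invented, and it follows the generic per-channel path rather than the 4-wide blocking used above:

    #include <cstdio>
    #include <vector>

    // Replicate each of the C input channels M times so that a depthwise kernel
    // with channel multiplier M can be run as if the multiplier were 1.
    static void premultiply_pixel(const float *in, float *out,
                                  unsigned int channels, unsigned int multiplier)
    {
        for (unsigned int c = 0; c < channels; c++)
            for (unsigned int m = 0; m < multiplier; m++)
                out[c * multiplier + m] = in[c];
    }

    int main()
    {
        const float in[3] = {1.f, 2.f, 3.f};   // one NHWC pixel, 3 channels
        std::vector<float> out(3 * 6);         // expanded to 18 channels (M = 6)
        premultiply_pixel(in, out.data(), 3, 6);
        for (float v : out) std::printf("%.0f ", v);  // prints 1 x6, 2 x6, 3 x6
        std::printf("\n");
        return 0;
    }
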
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp b/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
index b1fe66c..9805fd3 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
@@ -217,7 +217,7 @@
   template <typename StratType, typename OutputStage>
   static size_t get_element_size(const WorkspaceArgs<StratType, OutputStage> &args)
   {
-    return sizeof(T) * args.depthwise_args.input_channels;
+    return sizeof(T) * args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier;
   }
 
   template <class WorkspaceType, typename StratType, typename OutputStage>
@@ -278,6 +278,36 @@
 };
 
 
+/* Intermediate array to store results of premultiplication.
+ * Used as input to the kernel instead of the original input array.
+ */
+template <typename T>
+class IntermediateBufferElement
+{
+public:
+    struct Workspace
+    {
+        T *intermediate_buffer;
+    };
+
+    template <typename StratType, typename OutputStage>
+    static size_t get_element_size(const WorkspaceArgs<StratType, OutputStage> &args)
+    {
+      auto cols = args.depthwise_args.input_cols + args.depthwise_args.kernel_cols;
+      auto rows = args.strategy->get_input_rows() + args.depthwise_args.kernel_rows;
+      auto channels = args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier;
+      return sizeof(T) * cols * rows * channels;
+    }
+
+    template <class WorkspaceType, typename StratType, typename OutputStage>
+    static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<StratType, OutputStage> &args)
+    {
+      ws->intermediate_buffer = reinterpret_cast<T*>(buffer);
+      return reinterpret_cast<char *>(buffer) + get_element_size(args);
+    }
+};
+
+
 /* Container for requantization parameters.
  *
  * This removes the distinction between per-layer and per-channel
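
[Editor's note] A compile-only sketch of how such a workspace element is meant to be composed; the FakeArgs struct and its field names stand in for WorkspaceArgs and are assumptions, not the library's API. Each element reports its size, and initialise() hands back the advanced pointer so the next element can claim the space after it:

    #include <cstddef>
    #include <cstdint>

    // Hypothetical stand-in for WorkspaceArgs, for illustration only.
    struct FakeArgs
    {
        unsigned int input_cols, kernel_cols;
        unsigned int strategy_input_rows, kernel_rows;
        unsigned int input_channels, channel_multiplier;
    };

    // Mirrors the shape of IntermediateBufferElement: size one padded input tile
    // expanded to input_channels * channel_multiplier channels, then return the
    // advanced buffer pointer so further workspace elements can follow it.
    template <typename T>
    struct PremultiplyBufferElement
    {
        static std::size_t size(const FakeArgs &a)
        {
            const auto cols     = a.input_cols + a.kernel_cols;
            const auto rows     = a.strategy_input_rows + a.kernel_rows;
            const auto channels = a.input_channels * a.channel_multiplier;
            return sizeof(T) * cols * rows * channels;
        }

        static void *initialise(T *&slot, void *buffer, const FakeArgs &a)
        {
            slot = static_cast<T *>(buffer);
            return static_cast<std::uint8_t *>(buffer) + size(a);
        }
    };
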
diff --git a/src/core/NEON/kernels/arm_conv/pooling/depthfirst_driver.hpp b/src/core/NEON/kernels/arm_conv/pooling/depthfirst_driver.hpp
index b0aa62b..d0e8639 100644
--- a/src/core/NEON/kernels/arm_conv/pooling/depthfirst_driver.hpp
+++ b/src/core/NEON/kernels/arm_conv/pooling/depthfirst_driver.hpp
@@ -64,10 +64,10 @@
   std::unique_ptr<const IDepthfirstStrategy> m_strat;
 
   /* Compute the amount of working space required for a single thread. */
-  virtual size_t get_working_size_per_thread(unsigned int n_input_channels) const = 0;
+  virtual size_t get_working_size_per_thread() const = 0;
 
   /* Initialise the working space for a thread. */
-  virtual void initialise_working_space(void *, unsigned int n_input_channels) const = 0;
+  virtual void initialise_working_space(void *) const = 0;
 
   /* Compute a portion of the output tensor with padding. */
   virtual void compute_tile_padded(
@@ -148,8 +148,8 @@
   {
     // Get and initialise the working space for this thread.
     void *thread_working_space =
-      static_cast<uint8_t *>(working_space) + thread_id * this->get_working_size_per_thread(n_channels);
-    this->initialise_working_space(thread_working_space, n_channels);
+      static_cast<uint8_t *>(working_space) + thread_id * this->get_working_size_per_thread();
+    this->initialise_working_space(thread_working_space);
 
     // Construct convenient representations of the input/output tensors.
     TensorSpec<const TInput *> input_tensor(reinterpret_cast<const TInput *>(input), ld_input_row, ld_input_col);
@@ -289,14 +289,9 @@
   {
   }
 
-  size_t get_working_size(unsigned int n_threads) const override
+  size_t get_working_size(unsigned int n_threads) const override final
   {
-    return this->get_working_size(n_threads, this->m_args.n_channels);
-  }
-
-  size_t get_working_size(unsigned int n_threads, unsigned int n_channels) const override final
-  {
-    return n_threads * this->get_working_size_per_thread(n_channels);
+    return n_threads * this->get_working_size_per_thread();
   }
 };
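
[Editor's note] The same pattern, reduced to a minimal sketch (function and variable names are illustrative): the driver carves one flat allocation into equal per-thread chunks, which only works if the per-thread size is fixed up front rather than depending on a channel count passed at execute() time.

    #include <cstdint>
    #include <vector>

    void run_threads(std::size_t per_thread_bytes, unsigned int n_threads)
    {
        std::vector<std::uint8_t> workspace(per_thread_bytes * n_threads);
        for (unsigned int t = 0; t < n_threads; t++)
        {
            void *thread_ws = workspace.data() + t * per_thread_bytes;
            (void)thread_ws;  // each thread initialises and uses its own chunk here
        }
    }
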
 
diff --git a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp
index 8a6e63d..1ca4785 100644
--- a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp
@@ -91,17 +91,17 @@
 
   protected:
   /* Compute the amount of working space required for a single thread. */
-  size_t get_working_size_per_thread(unsigned int n_channels) const override
+  size_t get_working_size_per_thread() const override
   {
-    return sizeof(WorkingSpace) + n_channels * (sizeof(TInput) + sizeof(TOutput));
+    return sizeof(WorkingSpace) + this->m_args.n_channels * (sizeof(TInput) + sizeof(TOutput));
   }
 
   /* Initialise the working space for a thread. */
-  void initialise_working_space(void *raw_ws, unsigned int n_channels) const override
+  void initialise_working_space(void *raw_ws) const override
   {
     auto ws = reinterpret_cast<WorkingSpace *>(raw_ws);
     ws->input_buffer = ws + 1;
-    ws->output_buffer = reinterpret_cast<char *>(ws + 1) + sizeof(TInput) * n_channels;
+    ws->output_buffer = reinterpret_cast<char *>(ws + 1) + sizeof(TInput) * this->m_args.n_channels;
 
     // Fill the input buffer with an appropriate value
     TInput fill_val = 0;
@@ -119,6 +119,7 @@
     }
 
     auto ptr = reinterpret_cast<TInput *>(ws->input_buffer);
+    auto n_channels = this->m_args.n_channels;
     for (; n_channels; n_channels--)
     {
       *(ptr++) = fill_val;
diff --git a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp
index 07c5820..ded2c75 100644
--- a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp
+++ b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp
@@ -136,8 +136,8 @@
   const OutputStage m_os;
 
   protected:
-  size_t get_working_size_per_thread(unsigned int) const override { return 0; }
-  void initialise_working_space(void *, unsigned int) const override { /* Nothing */ }
+  size_t get_working_size_per_thread() const override { return 0; }
+  void initialise_working_space(void *) const override { /* Nothing */ }
 
   /* Compute a portion of the output tensor with padding. */
   void compute_tile_padded(
diff --git a/src/core/NEON/kernels/assembly/depthwise.hpp b/src/core/NEON/kernels/assembly/depthwise.hpp
index 8eb278c..dbd47cc 100644
--- a/src/core/NEON/kernels/assembly/depthwise.hpp
+++ b/src/core/NEON/kernels/assembly/depthwise.hpp
@@ -27,6 +27,7 @@
 #include "arm_gemm.hpp"
 #include "arm_gemm_local.hpp"
 #include "depthwise_common.hpp"
+#include "premultiply.hpp"
 
 namespace arm_conv
 {
@@ -38,8 +39,8 @@
     std::string     filter = "";
 
     DepthwiseConfig(DepthwiseMethod method)
-        : method(method){};
-    DepthwiseConfig(){};
+        : method(method) {};
+    DepthwiseConfig() {};
 };
 
 struct DepthwiseArgs
@@ -112,17 +113,64 @@
     }
 };
 
+template <typename TInput>
+struct Tile
+{
+    TInput *array;
+
+    unsigned int tile_rows     = 0;
+    unsigned int tile_cols     = 0;
+    unsigned int tile_channels = 0;
+
+    Tile(TInput *array, unsigned int tile_rows, unsigned int tile_cols, unsigned int tile_channels)
+        : array(array), tile_rows(tile_rows), tile_cols(tile_cols), tile_channels(tile_channels)
+    {
+    }
+
+    Tile()
+        : Tile(nullptr, 0, 0, 0)
+    {
+    }
+
+    void load_from(
+        const TInput      *input,
+        const unsigned int ld_row, const unsigned int ld_col,
+        const unsigned int n_rows, const unsigned int n_cols,
+        const int input_i, const int input_j,
+        const unsigned int channel_multiplier) const
+    {
+        const auto pad_top  = input_i < 0 ? -input_i : 0;
+        const auto pad_left = input_j < 0 ? -input_j : 0;
+
+        const auto padded_rows = std::min(n_rows - input_i, tile_rows) - pad_top;
+        const auto padded_cols = std::min(n_cols - input_j, tile_cols) - pad_left;
+
+        if(padded_rows < tile_rows || padded_cols < tile_cols)
+        {
+            memset(array, 0, tile_rows * tile_cols * tile_channels * sizeof(TInput));
+        }
+
+        do_premultiply<TInput>(
+            (TInput *)input + std::max(input_i, 0) * ld_row + std::max(input_j, 0) * ld_col,
+            ld_row, ld_col,
+            array + pad_top * tile_cols * tile_channels + pad_left * tile_channels,
+            tile_cols * tile_channels, tile_channels,
+            padded_rows, padded_cols, tile_channels / channel_multiplier,
+            channel_multiplier);
+    }
+};
+
 template <typename TInput, typename TWeight, typename TOutput>
 class DepthwiseCommon : public IDepthwiseCommon
 {
-    protected:
+protected:
     const DepthwiseArgs m_args; // Copy of arguments
     std::string         m_name{};
 
-    public:
+public:
     DepthwiseCommon(const DepthwiseArgs &args)
-        : m_args(args){};
-    DepthwiseCommon(DepthwiseCommon &)            = delete;
+        : m_args(args) {};
+    DepthwiseCommon(DepthwiseCommon &) = delete;
     DepthwiseCommon &operator=(DepthwiseCommon &) = delete;
 
     std::string name() const override
@@ -133,7 +181,7 @@
     void set_name(std::string name)
     {
         // Only allow the name to be set once
-        if (m_name.empty())
+        if(m_name.empty())
         {
             m_name = name;
         }
@@ -209,47 +257,47 @@
         // passed different input/output tensors. Dilation is handled at this
         // level; so we set the dilation in the arguments to zero.
         DepthwiseArgs args(this->m_args);
-        args.n_batches = batches;
-        args.input_rows = input_height;
-        args.input_cols = input_width;
+        args.n_batches      = batches;
+        args.input_rows     = input_height;
+        args.input_cols     = input_width;
         args.input_channels = channels;
-        args.output_rows = output_height;
-        args.output_cols = output_width;
-        args.padding = padding;
+        args.output_rows    = output_height;
+        args.output_cols    = output_width;
+        args.padding        = padding;
         args.dilation_rows = args.dilation_cols = 1;
 
-        auto ld_input_col_d = ld_input_col * m_args.dilation_cols;
-        auto ld_input_row_d = ld_input_row * m_args.dilation_rows;
+        auto ld_input_col_d  = ld_input_col * m_args.dilation_cols;
+        auto ld_input_row_d  = ld_input_row * m_args.dilation_rows;
         auto ld_output_col_d = ld_output_col * m_args.dilation_cols;
         auto ld_output_row_d = ld_output_row * m_args.dilation_rows;
 
-        for (size_t drow = 0; drow < m_args.dilation_rows; drow++)
+        for(size_t drow = 0; drow < m_args.dilation_rows; drow++)
         {
             size_t start_i;
             std::tie(args.output_rows, args.input_rows, start_i,
                      args.padding.top, args.padding.bottom) =
-                get_reduced_view_for_dilation(
-                        output_height, input_height, drow, m_args.dilation_rows,
-                        m_args.kernel_rows, m_args.stride_rows, padding.top);
+                         get_reduced_view_for_dilation(
+                             output_height, input_height, drow, m_args.dilation_rows,
+                             m_args.kernel_rows, m_args.stride_rows, padding.top);
 
-            auto input_row = static_cast<const TInput *>(input) + start_i * ld_input_row;
+            auto input_row  = static_cast<const TInput *>(input) + start_i * ld_input_row;
             auto output_row = static_cast<TOutput *>(output) + drow * ld_output_row;
 
-            if (args.output_rows)
+            if(args.output_rows)
             {
-                for (size_t dcol = 0; dcol < m_args.dilation_cols; dcol++)
+                for(size_t dcol = 0; dcol < m_args.dilation_cols; dcol++)
                 {
                     size_t start_j;
                     std::tie(args.output_cols, args.input_cols, start_j,
                              args.padding.left, args.padding.right) =
-                        get_reduced_view_for_dilation(
-                                output_width, input_width, dcol, m_args.dilation_cols,
-                                m_args.kernel_cols, m_args.stride_cols, padding.left);
+                                 get_reduced_view_for_dilation(
+                                     output_width, input_width, dcol, m_args.dilation_cols,
+                                     m_args.kernel_cols, m_args.stride_cols, padding.left);
 
-                    const TInput *input_col = input_row + start_j * ld_input_col;
-                    TOutput *output_col = output_row + dcol * ld_output_col;
+                    const TInput *input_col  = input_row + start_j * ld_input_col;
+                    TOutput      *output_col = output_row + dcol * ld_output_col;
 
-                    if (args.output_cols)
+                    if(args.output_cols)
                     {
                         this->execute_internal(
                             args, input_col, ld_input_col_d, ld_input_row_d, ld_input_batch, parameters,
@@ -261,7 +309,7 @@
         }
     }
 
-    protected:
+protected:
     virtual void execute_internal(
         const DepthwiseArgs &instance_args,
         const void          *input,
@@ -276,6 +324,11 @@
         void                *working_space,
         unsigned int         thread_id,
         unsigned int         n_threads) const = 0;
+
+    virtual bool uses_premultiply() const
+    {
+        return true;
+    }
 };
 
 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput>
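
[Editor's note] To make the padding arithmetic in Tile::load_from easier to follow, here is a small runnable example with invented tile and input sizes that reproduces the same pad/valid-region computation:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int tile_rows = 4, tile_cols = 4;
        const int n_rows = 10, n_cols = 10;   // input extent
        const int input_i = -1, input_j = 9;  // tile origin: one row above, near the right edge

        const int pad_top  = input_i < 0 ? -input_i : 0;
        const int pad_left = input_j < 0 ? -input_j : 0;

        const int valid_rows = std::min(n_rows - input_i, tile_rows) - pad_top;
        const int valid_cols = std::min(n_cols - input_j, tile_cols) - pad_left;

        // Expect pad_top=1, pad_left=0, valid=3x1: the remainder of the tile is
        // zero-filled before the premultiplied copy of the valid region is written.
        std::printf("pad_top=%d pad_left=%d valid=%dx%d\n", pad_top, pad_left, valid_rows, valid_cols);
        return 0;
    }
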
diff --git a/src/core/NEON/kernels/assembly/depthwise_common.hpp b/src/core/NEON/kernels/assembly/depthwise_common.hpp
index fea6326..a5db793 100644
--- a/src/core/NEON/kernels/assembly/depthwise_common.hpp
+++ b/src/core/NEON/kernels/assembly/depthwise_common.hpp
@@ -85,7 +85,7 @@
         size_t      ld_weight_row = 0) = 0;
 
     // Determine the amount of working space required
-    virtual size_t get_working_size(unsigned int n_threads, unsigned int n_input_channels) const = 0;
+    virtual size_t get_working_size(unsigned int n_threads) const = 0;
 
     // Execute the convolution over the specified area of memory.
     virtual void execute(
diff --git a/src/core/NEON/kernels/assembly/pool_common.hpp b/src/core/NEON/kernels/assembly/pool_common.hpp
index 599e18a..f1f70cf 100644
--- a/src/core/NEON/kernels/assembly/pool_common.hpp
+++ b/src/core/NEON/kernels/assembly/pool_common.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -66,7 +66,6 @@
 
     // Determine the amount of working space required.
     virtual size_t get_working_size(unsigned int num_threads) const = 0;
-    virtual size_t get_working_size(unsigned int num_threads, unsigned int n_channels) const = 0;
 
     // Execute pooling over the specified area of memory.
     virtual void execute(
diff --git a/src/core/NEON/kernels/assembly/pooling.hpp b/src/core/NEON/kernels/assembly/pooling.hpp
index 1b47853..e8db35c 100644
--- a/src/core/NEON/kernels/assembly/pooling.hpp
+++ b/src/core/NEON/kernels/assembly/pooling.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -122,11 +122,7 @@
     PoolingCommon(PoolingCommon &) = delete;
     PoolingCommon &operator=(PoolingCommon &) = delete;
 
-    size_t get_working_size(unsigned int, unsigned int) const override = 0;
-    size_t get_working_size(unsigned int n_threads) const override
-    {
-        return this->get_working_size(n_threads, m_args.n_channels);
-    }
+    size_t get_working_size(unsigned int) const override = 0;
 
     // Execute pooling over the specified area of memory.
     void execute(
diff --git a/src/core/NEON/kernels/assembly/premultiply.hpp b/src/core/NEON/kernels/assembly/premultiply.hpp
new file mode 100644
index 0000000..16f26de
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/premultiply.hpp
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+void do_premultiply_float_6(const float       *in_ptr,
+                            const unsigned int ld_row,
+                            const unsigned int ld_col,
+                            float             *out_ptr,
+                            const unsigned int out_ld_row,
+                            const unsigned int out_ld_col,
+                            const unsigned int tile_rows,
+                            const unsigned int tile_cols,
+                            const unsigned     input_channels);
+
+template <typename T>
+void do_premultiply(const T           *in_ptr,
+                    const unsigned int ld_row,
+                    const unsigned int ld_col,
+                    T                 *out_ptr,
+                    const unsigned int out_ld_row,
+                    const unsigned int out_ld_col,
+                    const unsigned int tile_rows,
+                    const unsigned int tile_cols,
+                    const unsigned     input_channels,
+                    const unsigned int channel_multiplier)
+{
+    if(sizeof(T) == 4 && channel_multiplier == 6)
+    {
+        do_premultiply_float_6(
+            (const float *)in_ptr, ld_row, ld_col,
+            (float *)out_ptr, out_ld_row, out_ld_col,
+            tile_rows, tile_cols,
+            input_channels);
+    }
+    else
+    {
+        for(unsigned int i = 0; i < tile_rows; i++)
+        {
+            const T *ip2 = in_ptr + i * ld_row;
+            T       *op2 = out_ptr + i * out_ld_row;
+            for(unsigned int j = 0; j < tile_cols; j++)
+            {
+                const T *ip = ip2;
+                T       *op = op2;
+                for(unsigned int c = 0; c < input_channels; c++)
+                {
+                    T val = *ip;
+                    ip++;
+
+                    for(unsigned int r = 0; r < channel_multiplier; r++)
+                    {
+                        op[r] = val;
+                    }
+                    op += channel_multiplier;
+                }
+                ip2 += ld_col;
+                op2 += out_ld_col;
+            }
+        }
+    }
+}
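
[Editor's note] A hedged usage sketch of the generic routine above; it assumes the assembly directory is on the include path and that premultiply.cpp is linked in so the float / x6 fast-path symbol resolves. Sizes and strides are illustrative: for a dense NHWC tile the column stride is the channel count and the row stride is columns times channels, with the output strides scaled by the multiplier.

    #include "premultiply.hpp"

    #include <vector>

    int main()
    {
        const unsigned int tile_rows = 2, tile_cols = 3;
        const unsigned int channels = 4, multiplier = 2;

        std::vector<float> in(tile_rows * tile_cols * channels, 1.0f);
        std::vector<float> out(tile_rows * tile_cols * channels * multiplier);

        do_premultiply<float>(in.data(),
                              /* ld_row */ tile_cols * channels, /* ld_col */ channels,
                              out.data(),
                              /* out_ld_row */ tile_cols * channels * multiplier,
                              /* out_ld_col */ channels * multiplier,
                              tile_rows, tile_cols, channels, multiplier);
        return 0;
    }
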
diff --git a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp
index 8cda5c6..e092c83 100644
--- a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp
+++ b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp
@@ -363,9 +363,9 @@
     return _kernel_asm->get_storage_size();
 }
 
-size_t CpuDepthwiseConv2dAssemblyWrapperKernel::get_working_size(unsigned int num_threads, unsigned int num_input_channels) const
+size_t CpuDepthwiseConv2dAssemblyWrapperKernel::get_working_size(unsigned int num_threads) const
 {
-    return _kernel_asm->get_working_size(num_threads, num_input_channels);
+    return _kernel_asm->get_working_size(num_threads);
 }
 
 bool CpuDepthwiseConv2dAssemblyWrapperKernel::is_configured() const
diff --git a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h
index 16d3b21..f61cb1b 100644
--- a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h
+++ b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h
@@ -98,12 +98,11 @@
 
     /** Get size of the workspace needed by the assembly kernel.
      *
-     * @param[in] num_threads        Maximum number of threads that are going to be spawned.
-     * @param[in] num_input_channels Number of channels of the input tensor.
+     * @param[in] num_threads Maximum number of threads that are going to be spawned.
      *
      * @return size of workspace
      */
-    size_t get_working_size(unsigned int num_threads, unsigned int num_input_channels) const;
+    size_t get_working_size(unsigned int num_threads) const;
 
     /** Was the asm kernel successfully configured?
      *
diff --git a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
index a5b9eca..d078155 100644
--- a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
+++ b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -77,7 +77,7 @@
 
     // Compute memory requirements for assembly kernels
     constexpr size_t alignment = 4096;
-    _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, dwc_wrapper->get_working_size(num_threads, src->dimension(0)), alignment });
+    _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, dwc_wrapper->get_working_size(num_threads), alignment });
     _pImpl->mem_req.push_back({ TensorType::ACL_INT_1, dwc_wrapper->get_storage_size(), alignment });
     _pImpl->asm_kernel = std::move(dwc_wrapper);
 }