Update Neon™ depthwise kernel

- Reduce code duplication and simplify the overall structure.
- Improve multi-threaded performance by sharing more data
  in lower-level caches.

Partially Resolves: COMPMID-5054
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Change-Id: Iac747f39b21c540122fa75218762631c4d787911
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7449
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Andrew Mundy
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
index 2862361..e58467b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,7 +24,8 @@
 
 #pragma once
 
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+#include "depthwise_depthfirst.hpp"
+#include "interleaves/generic_quantized_dot_product.hpp"
 
 #ifdef CYCLE_PROFILING
 #include "profiler.hpp"
@@ -35,492 +36,559 @@
 namespace arm_conv {
 namespace depthwise {
 
-namespace common
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class DepthfirstMultiplierStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, Nothing>
 {
-  template <typename strategy, typename F>
-  void depthwise_multiplier_execute(
-    const F execute_tile,
-    typename strategy::input_type pad_value,
-    const DepthwiseArgs &args,
-    const unsigned int batches,
-    const unsigned int input_height,
-    const unsigned int input_width,
-    const unsigned int input_channels,
-    const PaddingValues &padding,
-    const void *const _input,
-    const size_t ld_input_col,
-    const size_t ld_input_row,
-    const size_t ld_input_batch,
-    const void *const parameters,
-    const size_t param_stride,
-    const unsigned int output_height,
-    const unsigned int output_width,
-    void *const _output,
-    const size_t ld_output_col,
-    const size_t ld_output_row,
-    const size_t ld_output_batch,
-    void *const _working_space,
-    const unsigned int thread_id,
-    const unsigned int n_threads
-  )
+  using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, Nothing>;
+
+  protected:
+  virtual interleaves::PackingArguments get_packing_args(const DepthwiseArgs &args) const
   {
-    using TInput = typename strategy::input_type;
-    using TOutput = typename strategy::return_type;
-
-    // Determine what portion of the work to do.
-    const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
-    const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
-    const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
-
-    // Cast input and output pointers into the right types
-    const TInput *const inptr = static_cast<const TInput *>(_input);
-    TOutput *const outptr = static_cast<TOutput *>(_output);
-
-    // To simplify the kernel, we process padded or non-NCHW-ordered input into
-    // a form which can be consumed by the kernel. This data is stored here and
-    // passed into the kernel as an array of N pointers (one per row of the
-    // input).
-    TInput rearranged_input[strategy::input_rows][strategy::input_col_quads*(16 / sizeof(TInput))];
-    const TInput *inptrs[strategy::input_rows];
-
-    // Create an array for the output pointers
-    TOutput * _outptr_array[strategy::output_rows * strategy::output_cols];
-    TOutput **const outptr_array = _outptr_array;
-
-    // Allocate portions of the working space
-    uint8_t *const working_space = static_cast<uint8_t *>(_working_space);
-    TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
-
-    // For each output tile, construct the requisite set of pointers and call
-    // into the kernel.
-    for (unsigned int batch = 0; batch < batches; batch++)
-    {
-      // Get batch pointers
-      const auto inptr_batch = inptr + batch * ld_input_batch;
-      const auto outptr_batch = outptr + batch * ld_output_batch;
-
-      for (int start_out_i = start_out_height;
-           start_out_i < end_out_height;
-           start_out_i += static_cast<int>(strategy::output_rows))
+    return interleaves::PackingArguments(
+      args.kernel_rows, args.kernel_cols, sizeof(TWeight),
+      true, sizeof(TAccum),
+      this->get_vl_type(),
+      sizeof(TAccum), 1,
+      [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
       {
-        const int end_out_i = start_out_i + strategy::output_rows;
-        const int start_in_i = start_out_i * strategy::stride_rows - padding.top;
-        const int end_in_i = start_in_i + strategy::input_rows;
-
-        // Compute top/bottom padding
-        const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
-        const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
-        const unsigned int valid_output_rows = std::min(
-          end_out_i - start_out_i,
-          static_cast<int>(output_height) - start_out_i
-        );
-
-        for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
+        if (pos < args.kernel_rows * args.kernel_cols)
         {
-          const int start_in_j = start_out_j * strategy::stride_cols - args.padding.left;
-          const int pad_left = -std::min(0, start_in_j);
-
-          const int end_out_j = start_out_j + strategy::output_cols;
-          const int end_in_j = start_in_j + strategy::input_cols;
-
-          const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
-          const unsigned int valid_output_cols = std::min(
-            end_out_j - start_out_j,
-            static_cast<int>(output_width) - start_out_j
-          );
-
-          // Construct the output pointer array.
-          TOutput **outptr_pos = outptr_array;
-          for (auto i = 0u; i < valid_output_rows; i++)
-          {
-            unsigned int j = 0u;
-            TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
-            for (; j < valid_output_cols; j++)
-            {
-              *(outptr_pos++) = colptr;
-               colptr += ld_output_col;
-            }
-            for (; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-          for (auto i = valid_output_rows; i < strategy::output_rows; i++)
-          {
-            for (auto j = 0u; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-
-          start_out_j += strategy::output_cols;
-
-          const uint8_t *params = static_cast<const uint8_t *>(parameters);
-
-          // Loop over the input channels
-          for (unsigned int in_c = 0; in_c < input_channels; in_c++)
-          {
-            // Construct the input array - first fill with padding values and
-            // then fill in correct values.
-            for (unsigned int i = 0; i < strategy::input_rows; i++)
-            {
-              for (unsigned int j = 0;
-                   j < (16 / sizeof(TInput)) * strategy::input_col_quads; j++)
-              {
-                rearranged_input[i][j] = pad_value;
-              }
-              inptrs[i] = rearranged_input[i];
-            }
-
-            auto inptr_row = inptr_batch + in_c +
-                             (start_in_i + pad_top) * ld_input_row +
-                             (start_in_j + pad_left) * ld_input_col;
-            if (ld_input_col == 1 && !pad_left &&
-                start_in_j + (16 / sizeof(TInput)) * strategy::input_col_quads < input_width)
-            {
-              // The input tensor is already in NCHW format, and we're reading
-              // an unpadded section of it - allow the kernel to read it
-              // directly.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                inptrs[i] = inptr_row;
-                inptr_row += ld_input_row;
-              }
-            }
-            else
-            {
-              // Either the input tensor isn't in NCHW format, or we're reading
-              // a padded section. Copy the relevant portion of the input here
-              // and allow the kernel to read this.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                auto inptr_col = inptr_row;
-                for (unsigned int j = pad_left; j < strategy::input_cols - pad_right; j++)
-                {
-                  rearranged_input[i][j] = *inptr_col;
-                  inptr_col += ld_input_col;
-                }
-                inptr_row += ld_input_row;
-              }
-            }
-
-            execute_tile(inptrs, outptr_array, params);
-
-            // Progress the output pointers
-            TOutput **outptr_pos = outptr_array;
-            for (auto i = 0u; i < strategy::output_rows * strategy::output_cols; i++)
-            {
-              outptr_pos[i] += args.channel_multiplier;
-            }
-
-            // Progress the pointer into the parameters
-            params += param_stride;
-          }
+          y = pos % args.kernel_cols;
+          x = pos / args.kernel_cols;
+          return true;
         }
+        return false;
       }
-    }
-  }
-}
-
-template <class strategy>
-class DepthwiseDepthfirstWithMultiplier :
-  public DepthwiseCommon<typename strategy::input_type,
-                         typename strategy::weight_type,
-                         typename strategy::return_type>
-{
-  using TInput = typename strategy::input_type;
-  using TWeight = typename strategy::weight_type;
-  using TOutput = typename strategy::return_type;
-  using TAccum = typename strategy::bias_type;
-
-  size_t sizeof_output_buffer(unsigned int n_channels) const
-  {
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(strategy::vl_type);
-    const auto rounded_channels = arm_gemm::roundup(n_channels, vl);
-    return sizeof(TOutput) * rounded_channels;
+    );
   }
 
   public:
-  DepthwiseDepthfirstWithMultiplier(const DepthwiseArgs &args) : DepthwiseCommon<TInput, TWeight, TOutput>(args)
+  using Parent::Parent;
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleaves::get_storage_size_generic(this->get_packing_args(args), args);
+  }
+
+  void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const Nothing &, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+  {
+    interleaves::pack_parameters_generic(
+      this->get_packing_args(args), args,
+      buffer, biases, weights, ld_weight_col, ld_weight_row
+    );
+  }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const void *,  // Ravelled bias, weights, and quantization parameters
+    unsigned int,  // # output channels
+    TAccum, TAccum  // Min and max activation clamps
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput>
+class DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t> : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+  using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
+
+  public:
+  using Parent::Parent;
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleaves::quantized::get_storage_size(args, this->get_vl_type(), this->get_accumulator_depth_vl());
+  }
+
+  void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+  {
+    interleaves::quantized::pack_parameters<TWeight>(
+      buffer, reinterpret_cast<const int32_t *>(biases),
+      reinterpret_cast<const TWeight *>(weights), ld_weight_col, ld_weight_row,
+      args, qp, this->get_vl_type(), this->get_accumulator_depth_vl()
+    );
+  }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const void *,  // Ravelled bias, weights, and quantization parameters
+    unsigned int,  // # output channels
+    const arm_gemm::Requantize32 &
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class GenericDepthfirstMultiplierKernelStrategy
+{
+  const arm_gemm::VLType m_vl_type;
+  const unsigned int m_output_rows, m_output_cols;
+
+  public:
+  GenericDepthfirstMultiplierKernelStrategy(unsigned int output_rows, unsigned int output_cols, arm_gemm::VLType vl_type)
+  : m_vl_type(vl_type), m_output_rows(output_rows), m_output_cols(output_cols)
   {
   }
 
-  DepthwiseDepthfirstWithMultiplier(DepthwiseDepthfirstWithMultiplier &) = delete;
-  DepthwiseDepthfirstWithMultiplier &operator=(DepthwiseDepthfirstWithMultiplier &) = delete;
+  virtual ~GenericDepthfirstMultiplierKernelStrategy() = default;
+
+  arm_gemm::VLType get_vl_type(void) const { return m_vl_type; }
+  unsigned int get_output_rows(void) const { return m_output_rows; }
+  unsigned int get_output_cols(void) const { return m_output_cols; }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const TWeight *,  // Ravelled weight parameters
+    const TAccum *,  // Bias,
+    unsigned int, unsigned int,  // Number of kernel points, number of output channels
+    TAccum, TAccum  // Activation minimum and maximum
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+class GenericDepthfirstMultiplierKernelStrategy<TInput, TWeight, TOutput, int32_t>
+{
+  const arm_gemm::VLType m_vl_type;
+  const unsigned int m_output_rows, m_output_cols;
+
+  public:
+  GenericDepthfirstMultiplierKernelStrategy(unsigned int output_rows, unsigned int output_cols, arm_gemm::VLType vl_type)
+  : m_vl_type(vl_type), m_output_rows(output_rows), m_output_cols(output_cols)
+  {
+  }
+
+  virtual ~GenericDepthfirstMultiplierKernelStrategy() = default;
+
+  arm_gemm::VLType get_vl_type(void) const { return m_vl_type; }
+  unsigned int get_output_rows(void) const { return m_output_rows; }
+  unsigned int get_output_cols(void) const { return m_output_cols; }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const TWeight *,  // Ravelled weight parameters
+    const int32_t *,  // Bias,
+    unsigned int, unsigned int,  // Number of kernel points, number of output channels
+    const int32_t *, const int32_t *, const int32_t *,  // Per-channel left-shifts, multipliers, right-shifts (need to account for start channel)
+    const arm_gemm::Requantize32 &
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+template <typename TInput,
+          typename TWeight=TInput,
+          typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TInput>::Type,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class GenericDepthfirstMultiplierStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+  using KernelStrategyType = GenericDepthfirstMultiplierKernelStrategy<TInput, TWeight, TOutput, TAccum>;
+  std::unique_ptr<KernelStrategyType> m_kern;
+
+  protected:
+  virtual interleaves::PackingArguments get_packing_args(const DepthwiseArgs &args) const
+  {
+    return interleaves::PackingArguments(
+      args.kernel_rows, args.kernel_cols, sizeof(TWeight),
+      false, sizeof(TAccum),
+      this->get_vl_type(),
+      sizeof(TAccum), 1,
+      [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
+      {
+        if (pos < args.kernel_rows * args.kernel_cols)
+        {
+          y = pos % args.kernel_cols;
+          x = pos / args.kernel_cols;
+          return true;
+        }
+        return false;
+      }
+    );
+  }
+
+  public:
+  GenericDepthfirstMultiplierStrategy(KernelStrategyType *kern, const DepthwiseArgs &args)
+  : DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>(
+      kern->get_output_rows(), kern->get_output_cols(),
+      args.kernel_rows, args.kernel_cols,
+      args.stride_rows, args.stride_cols
+    ),
+    m_kern(kern)
+  {
+  };
+
+  arm_gemm::VLType get_vl_type(void) const override { return m_kern->get_vl_type(); }
+  const typename KernelStrategyType::KernelType get_kernel(void) const { return m_kern->get_kernel(); }
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleaves::get_storage_size_generic(this->get_packing_args(args), args);
+  }
+
+  void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const OutputStage &, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+  {
+    interleaves::pack_parameters_generic(
+      this->get_packing_args(args), args,
+      buffer, biases, weights, ld_weight_col, ld_weight_row
+    );
+  }
+};
+
+// Specialise elements of the wrapper based on the type of kernel.
+namespace depthfirst_multiplier {
+
+/* Working space element which contains a pointer for each row of input, a row
+ * of padding, and a space which can be used to construct an NCHW-ordered patch
+ * of input.
+ */
+template <typename T, bool IsGeneric=false, typename OutputStage=Nothing>
+class InputPatchElement
+{
+  public:
+  struct Workspace
+  {
+    constexpr static bool InputPatchIsGeneric = IsGeneric;
+    const T **input_rows;
+    T *input_padding;
+    T *input_patch;
+  };
+
+  static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    return sizeof_input_rows(args) + sizeof_input_padding(args) + sizeof_input_patch(args);
+  }
+
+  template <class WorkspaceType>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    auto buffer_bytes = reinterpret_cast<char *>(buffer);
+
+    ws->input_rows = reinterpret_cast<const T **>(buffer_bytes);
+    buffer_bytes += sizeof_input_rows(args);
+
+    ws->input_padding = reinterpret_cast<T*>(buffer_bytes);
+    buffer_bytes += sizeof_input_padding(args);
+
+    ws->input_patch = reinterpret_cast<T*>(buffer_bytes);
+    buffer_bytes += sizeof_input_patch(args);
+
+    // Initialise the padding
+    memset(ws->input_padding,
+           get_input_buffer_fill_value(args.output_stage),
+           sizeof_input_padding(args));
+
+    return buffer_bytes;
+  }
+
+  protected:
+  static size_t sizeof_input_rows(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    if (IsGeneric)
+    {
+      return sizeof(T *) * args.strategy->get_output_rows() * args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+    }
+    else
+    {
+      return sizeof(T *) * args.strategy->get_input_rows();
+    }
+  }
+
+  static size_t sizeof_input_padding(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    // Round-up the number of columns to be a whole number of QUADS
+    auto input_cols = arm_gemm::roundup<size_t>(args.strategy->get_input_cols(), 16 / sizeof(T));
+    return sizeof(T) * input_cols;
+  }
+
+  static size_t sizeof_input_patch(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    if (IsGeneric)
+    {
+      // Round-up the number of columns to be a whole number of QUADS
+      auto output_cols = arm_gemm::roundup<size_t>(args.strategy->get_output_cols(), 16 / sizeof(T));
+      const auto kernel_points = args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+      return sizeof(T) * kernel_points * args.strategy->get_output_rows() * output_cols;
+    }
+    else
+    {
+      // Round-up the number of columns to be a whole number of QUADS
+      auto input_cols = arm_gemm::roundup<size_t>(args.strategy->get_input_cols(), 16 / sizeof(T));
+      return sizeof(T) * args.strategy->get_input_rows() * input_cols;
+    }
+  }
+};
+
+template <bool IsGeneric, typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct StrategyType
+{
+  using Type = DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, TAccum>;
+
+  template <typename WorkspaceType>
+  static void execute(
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const OutputStage &, const unsigned int,
+    const void *parameters, const void *
+  )
+  {
+    strat->get_kernel()(
+      ws->input_rows,
+      ws->outptr_array,
+      parameters, args.channel_multiplier,
+      ws->activation_min, ws->activation_max
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct StrategyType<true, TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+  using Type = GenericDepthfirstMultiplierStrategy<TInput, TWeight, TOutput, TAccum, OutputStage>;
+
+  template <typename WorkspaceType>
+  static void execute(
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const OutputStage &, const unsigned int start_output_channel,
+    const void *parameters, const void *bias
+  )
+  {
+    strat->get_kernel()(
+      ws->input_rows, ws->outptr_array,
+      reinterpret_cast<const TWeight *>(parameters),
+      bias == nullptr ? nullptr : reinterpret_cast<const TAccum *>(bias) + start_output_channel,
+      strat->get_kernel_rows() * strat->get_kernel_cols(),
+      args.channel_multiplier,
+      ws->activation_min, ws->activation_max
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+struct StrategyType<false, TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+  using Type = DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t>;
+
+  template <typename WorkspaceType>
+  static void execute(
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const arm_gemm::Requantize32 &qp, const unsigned int,
+    const void *parameters, const void *
+  )
+  {
+    strat->get_kernel()(
+      ws->input_rows,
+      ws->outptr_array,
+      parameters, args.channel_multiplier,
+      qp
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+struct StrategyType<true, TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+  using Type = GenericDepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
+
+  template <typename WorkspaceType>
+  static void execute(
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const arm_gemm::Requantize32 &qp, const unsigned int start_output_channel,
+    const void *parameters, const void *
+  )
+  {
+    auto get_ptr = [start_output_channel] (const int32_t *ptr) -> const int32_t *
+    {
+      return ptr == nullptr ? nullptr : ptr + start_output_channel;
+    };
+
+    strat->get_kernel()(
+      ws->input_rows, ws->outptr_array,
+      reinterpret_cast<const TWeight *>(parameters),
+      get_ptr(qp.bias),
+      strat->get_kernel_rows() * strat->get_kernel_cols(),
+      args.channel_multiplier,
+      get_ptr(qp.per_channel_left_shifts),
+      get_ptr(qp.per_channel_muls),
+      get_ptr(qp.per_channel_right_shifts),
+      qp
+    );
+  }
+};
+
+template <bool IsGeneric> struct PrepareInputSample;
+
+template <> struct PrepareInputSample<false>
+{
+  template <typename WorkspaceType, typename StrategyType, typename T>
+  static void execute(
+    const DepthwiseArgs &, WorkspaceType *ws, const StrategyType *strat,
+    T *base_ptr, size_t ld_row, size_t ld_col,
+    const unsigned int input_pad_top, const unsigned int valid_rows,
+    const unsigned int input_pad_left, const unsigned int valid_cols
+  )
+  {
+    fill_nchw_patch_array(
+      ws->input_rows, ws->input_patch, strat->get_input_rows(), strat->get_input_cols(),
+      base_ptr, ld_row, ld_col,
+      ws->input_padding,
+      input_pad_top, valid_rows,
+      input_pad_left, valid_cols
+    );
+  }
+};
+
+template <> struct PrepareInputSample<true>
+{
+  template <typename WorkspaceType, typename StrategyType, typename T>
+  static void execute(
+    const DepthwiseArgs &args, WorkspaceType *ws, const StrategyType *strat,
+    T *base_ptr, size_t ld_row, size_t ld_col,
+    const unsigned int input_pad_top, const unsigned int valid_rows,
+    const unsigned int input_pad_left, const unsigned int valid_cols
+  )
+  {
+    fill_patch_array_generic_kernel(
+      ws->input_rows, ws->input_patch,
+      strat->get_output_rows(), strat->get_output_cols(),
+      args.kernel_rows, args.kernel_cols,
+      args.stride_rows, args.stride_cols,
+      base_ptr, ld_row, ld_col,
+      ws->input_padding,
+      input_pad_top, valid_rows,
+      input_pad_left, valid_cols
+    );
+  }
+};
+
+}  // namespace depthfirst_multiplier
+
+template <typename TInput,
+          typename TWeight=TInput,
+          typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TInput>::Type,
+          bool is_generic=false,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirstMultiplier : public DepthfirstDriver<TInput, TWeight, TOutput>
+{
+  protected:
+  using StratType = typename depthfirst_multiplier::StrategyType<is_generic, TInput, TWeight, TOutput, TAccum, OutputStage>::Type;
+  using WorkspaceManager = Workspace<
+    OutputArrayElement<TOutput>,
+    depthfirst_multiplier::InputPatchElement<TInput, is_generic, OutputStage>,
+    ActivationsElement<TOutput, OutputStage>
+  >;
+  using WorkingSpace = typename WorkspaceManager::WorkspaceType;
+
+  OutputStage m_os;  // Copy of the output parameters
+  const void *m_bias = nullptr;  // Copy of the bias (should we need it)
+
+  public:
+  DepthwiseDepthfirstMultiplier(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os = {})
+  : DepthfirstDriver<TInput, TWeight, TOutput>(strat, args), m_os(os)
+  {
+  }
+
+  DepthwiseDepthfirstMultiplier(DepthwiseDepthfirstMultiplier &) = delete;
+  DepthwiseDepthfirstMultiplier &operator=(DepthwiseDepthfirstMultiplier &) = delete;
 
   size_t get_storage_size(void) const override
   {
-    // TODO What if we insert extra padding? Biases are a different size to the inputs, ...
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(strategy::vl_type);
-    const auto rounded_channels = this->m_args.input_channels * arm_gemm::roundup(this->m_args.channel_multiplier, vl);
-    return (1 + this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight);
+    return reinterpret_cast<const StratType *>(this->m_strat.get())
+      ->get_storage_size(this->m_args);
   }
 
-  void pack_parameters(void *_buffer, const void *_biases, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override
+  void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
   {
-    // TODO What if the kernel needs a different packing function?
-
-    // Cast the pointers
-    float *buffer = static_cast<float *>(_buffer);
-    const float *biases = static_cast<const float *>(_biases);
-    const float *const weights = static_cast<const float *>(_weights);
-
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(strategy::vl_type);
-    ld_weight_col = (ld_weight_col == 0) ? this->m_args.channel_multiplier * this->m_args.input_channels : ld_weight_col;
-    ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row;
-
-    for (unsigned int in_c = 0; in_c < this->m_args.input_channels; in_c++)
-    {
-      for (unsigned int n = 0; n < this->m_args.channel_multiplier; n += vl)
-      {
-        const unsigned int out_c = in_c * this->m_args.channel_multiplier + n;
-        const unsigned int todo = std::min(vl, this->m_args.channel_multiplier - n);
-
-        // Copy across the correct amount of bias (or 0)
-        for (unsigned int i = 0; i < todo; i++)
-        {
-          buffer[i] = (biases == nullptr) ? 0 : biases[out_c + i];
-        }
-        buffer += vl;
-
-        // Copy each of the weights in turn
-        auto weights_row = weights + out_c;
-        for (unsigned int i = 0; i < this->m_args.kernel_rows; i++)
-        {
-          auto weights_col = weights_row;
-
-          for (unsigned int j = 0; j < this->m_args.kernel_cols; j++)
-          {
-            for (unsigned int m = 0; m < todo; m++)
-            {
-              buffer[m] = weights_col[m];
-            }
-            buffer += vl;
-
-            weights_col += ld_weight_col;
-          }
-
-          weights_row += ld_weight_row;
-        }
-      }
-    }
+    reinterpret_cast<const StratType *>(this->m_strat.get())
+      ->pack_parameters(this->m_args, buffer, biases, m_os, weights, ld_weight_col, ld_weight_row);
+    m_bias = biases;
+    depthwise_depthfirst::stash_bias(m_os, biases);
   }
 
-  size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override
+  size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
   {
-    const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier;
-    return n_threads * sizeof_output_buffer(n_output_channels);
+    DepthwiseArgs args(this->m_args);
+    args.input_channels = n_input_channels;
+    return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
   }
-  
-  using DepthwiseCommon<typename strategy::input_type, typename strategy::weight_type, typename strategy::return_type>::execute;
-  void execute(
-    const unsigned int batches,
-    const unsigned int input_height,
-    const unsigned int input_width,
-    const unsigned int input_channels,
-    const PaddingValues &padding,
-    const void *const _input,
-    const size_t ld_input_col,
-    const size_t ld_input_row,
-    const size_t ld_input_batch,
-    const void *const parameters,
-    const unsigned int output_height,
-    const unsigned int output_width,
-    void *const _output,
-    const size_t ld_output_col,
-    const size_t ld_output_row,
-    const size_t ld_output_batch,
-    void *const _working_space,
-    const unsigned int thread_id,
-    const unsigned int n_threads
+
+  void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
+  {
+    DepthwiseArgs args(this->m_args);
+    args.input_channels = n_input_channels;
+    return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
+  }
+
+  void compute_tile_padded(
+    unsigned int output_i, unsigned int output_j,
+    unsigned int output_channel_start, unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space_raw
   ) const override
   {
-    strategy strat(this->m_args.cpu_info);
-#ifdef CYCLE_PROFILING
-    arm_gemm::profiler prof;
-#endif
+    // Reinterpret the caller-provided scratch memory as this kernel's working space
+    auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
 
-    // Compute activation values
-    TAccum activation_min = std::numeric_limits<TAccum>::has_infinity ? -std::numeric_limits<TAccum>::infinity() : std::numeric_limits<TAccum>::min();
-    TAccum activation_max = std::numeric_limits<TAccum>::has_infinity ? std::numeric_limits<TAccum>::infinity() : std::numeric_limits<TAccum>::max();
+    const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+    const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);  // Rows of top padding covered by this tile
+    const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);  // First real input row read
 
-    switch (this->m_args.activation.type)
+    const int ij = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+    const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);  // Columns of left padding covered by this tile
+    const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);  // First real input column read
+
+    // Compute the output pointer array (out-of-range positions presumably point at
+    // ws->output_buffer); it is advanced after every kernel invocation below.
+    fill_pointer_array(
+      ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+      output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+      output.ld_row, output.ld_col,
+      ws->output_buffer,
+      0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+      0, this->m_args.output_cols - output_j  // Left padding, # valid columns
+    );
+
+    // Bytes of packed parameters consumed per input channel, i.e. per iteration of the channel loop below
+    DepthwiseArgs single_iter(this->m_args);
+    single_iter.input_channels = 1;
+    const size_t parameter_stride = reinterpret_cast<const StratType *>(this->m_strat.get())
+      ->get_storage_size(single_iter);
+
+    for (; output_channel_start < output_channel_end;
+         output_channel_start += this->m_args.channel_multiplier)
     {
-      case arm_gemm::Activation::Type::BoundedReLU:
-        activation_max = static_cast<TAccum>(this->m_args.activation.param1);
-        // Fall through
-      case arm_gemm::Activation::Type::ReLU:
-        activation_min = static_cast<TAccum>(0);
-        break;
-      default:
-        break;
-    }
+      // Output channels are handled channel_multiplier at a time; find the input channel they read
+      const auto input_channel = output_channel_start / this->m_args.channel_multiplier;
 
-    // Determine what portion of the work to do.
-    const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
-    const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
-    const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
+      // Prepare the (possibly padded) input patch for input_channel; ws is passed in, presumably as the staging area
+      depthfirst_multiplier::PrepareInputSample<is_generic>::execute(
+        this->m_args, ws, this->m_strat.get(),
+        input.base + input_channel + input_i*input.ld_row + input_j*input.ld_col, input.ld_row, input.ld_col,
+        input_pad_top, this->m_args.input_rows - input_i,  // Top padding, # rows from input_i (NOTE(review): unsigned; assumes input_i <= input_rows)
+        input_pad_left, this->m_args.input_cols - input_j  // Left padding, # columns from input_j (same unsigned assumption)
+      );
 
-    // Need a stride over blocks of parameters
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(strategy::vl_type);
-    const unsigned int param_stride =
-      arm_gemm::roundup(this->m_args.channel_multiplier, vl) *
-      (sizeof(TAccum) + sizeof(TWeight) * strategy::kernel_rows * strategy::kernel_cols);
+      // Execute the kernel for this group of channel_multiplier output channels
+      depthfirst_multiplier::StrategyType<is_generic, TInput, TWeight, TOutput, TAccum, OutputStage>::execute(
+        this->m_args, ws, reinterpret_cast<const StratType *>(this->m_strat.get()), m_os, output_channel_start,
+        parameters, m_bias
+      );
 
-    // Cast input and output pointers into the right types
-    const TInput *const inptr = static_cast<const TInput *>(_input);
-    TOutput *const outptr = static_cast<TOutput *>(_output);
-
-    // To simplify the kernel, we process padded or non-NCHW-ordered input into
-    // a form which can be consumed by the kernel. This data is stored here and
-    // passed into the kernel as an array of N pointers (one per row of the
-    // input).
-    TInput rearranged_input[strategy::input_rows][strategy::input_col_quads*4];
-    const TInput *inptrs[strategy::input_rows];
-
-    // Create an array for the output pointers
-    TOutput * _outptr_array[strategy::output_rows * strategy::output_cols];
-    TOutput **const outptr_array = _outptr_array;
-
-    // Allocate portions of the working space
-    uint8_t *const working_space = static_cast<uint8_t *>(_working_space) + get_working_size(thread_id, input_channels);
-    TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
-
-    // For each output tile, construct the requisite set of pointers and call
-    // into the kernel.
-    for (unsigned int batch = 0; batch < batches; batch++)
-    {
-      // Get batch pointers
-      const auto inptr_batch = inptr + batch * ld_input_batch;
-      const auto outptr_batch = outptr + batch * ld_output_batch;
-
-      for (int start_out_i = start_out_height;
-           start_out_i < end_out_height;
-           start_out_i += static_cast<int>(strategy::output_rows))
+      // Advance every output pointer past the channel_multiplier channels just written
+      for (unsigned int n = 0; n < this->m_strat->get_output_rows() * this->m_strat->get_output_cols(); n++)
       {
-        const int end_out_i = start_out_i + strategy::output_rows;
-        const int start_in_i = start_out_i * strategy::stride_rows - padding.top;
-        const int end_in_i = start_in_i + strategy::input_rows;
-
-        // Compute top/bottom padding
-        const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
-        const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
-        const unsigned int valid_output_rows = std::min(
-          end_out_i - start_out_i,
-          static_cast<int>(output_height) - start_out_i
-        );
-
-        for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
-        {
-          const int start_in_j = start_out_j * strategy::stride_cols - this->m_args.padding.left;
-          const int pad_left = -std::min(0, start_in_j);
-
-          const int end_out_j = start_out_j + strategy::output_cols;
-          const int end_in_j = start_in_j + strategy::input_cols;
-
-          const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
-          const unsigned int valid_output_cols = std::min(
-            end_out_j - start_out_j,
-            static_cast<int>(output_width) - start_out_j
-          );
-
-          // Construct the output pointer array.
-          TOutput **outptr_pos = outptr_array;
-          for (auto i = 0u; i < valid_output_rows; i++)
-          {
-            unsigned int j = 0u;
-            TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
-            for (; j < valid_output_cols; j++)
-            {
-              *(outptr_pos++) = colptr;
-               colptr += ld_output_col;
-            }
-            for (; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-          for (auto i = valid_output_rows; i < strategy::output_rows; i++)
-          {
-            for (auto j = 0u; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-
-          start_out_j += strategy::output_cols;
-
-          const uint8_t *params = static_cast<const uint8_t *>(parameters);
-
-          // Loop over the input channels
-          for (unsigned int in_c = 0; in_c < input_channels; in_c++)
-          {
-            // Construct the input array - first fill with padding values and
-            // then fill in correct values.
-            for (unsigned int i = 0; i < strategy::input_rows; i++)
-            {
-              for (unsigned int j = 0; j < 4 * strategy::input_col_quads; j++)
-              {
-                rearranged_input[i][j] = static_cast<TInput>(0);
-              }
-              inptrs[i] = rearranged_input[i];
-            }
-
-            auto inptr_row = inptr_batch + in_c +
-                             (start_in_i + pad_top) * ld_input_row +
-                             (start_in_j + pad_left) * ld_input_col;
-            if (ld_input_col == 1 && !pad_left &&
-                start_in_j + 4 * strategy::input_col_quads < input_width)
-            {
-              // The input tensor is already in NCHW format, and we're reading
-              // an unpadded section of it - allow the kernel to read it
-              // directly.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                inptrs[i] = inptr_row;
-                inptr_row += ld_input_row;
-              }
-            }
-            else
-            {
-              // Either the input tensor isn't in NCHW format, or we're reading
-              // a padded section. Copy the relevant portion of the input here
-              // and allow the kernel to read this.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                auto inptr_col = inptr_row;
-                for (unsigned int j = pad_left; j < strategy::input_cols - pad_right; j++)
-                {
-                  rearranged_input[i][j] = *inptr_col;
-                  inptr_col += ld_input_col;
-                }
-                inptr_row += ld_input_row;
-              }
-            }
-
-            {
-#ifdef CYCLE_PROFILING
-              auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(strategy::output_rows * strategy::output_cols * this->m_args.channel_multiplier * strategy::kernel_rows * strategy::kernel_cols));
-#endif
-              strat.kernel(
-                inptrs, outptr_array, params,
-                this->m_args.channel_multiplier,
-                activation_min, activation_max
-              );
-            }
-
-            // Progress the output pointers
-            TOutput **outptr_pos = outptr_array;
-            for (auto i = 0u; i < strategy::output_rows * strategy::output_cols; i++)
-            {
-              outptr_pos[i] += this->m_args.channel_multiplier;
-            }
-
-            // Progress the pointer into the parameters
-            params += param_stride;
-          }
-        }
+        ws->outptr_array[n] += this->m_args.channel_multiplier;
       }
+
+      // Step to the next input channel's packed parameters
+      parameters = reinterpret_cast<const char *>(parameters) + parameter_stride;
     }
   }
 };