Update Neon™ depthwise kernel

- Reduce duplication and simplify overall structure.
- Improve multi-threaded performance by sharing more data
  in lower-level caches.

Partially Resolves: COMPMID-5054
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Change-Id: Iac747f39b21c540122fa75218762631c4d787911
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7449
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Andrew Mundy
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
new file mode 100644
index 0000000..e02998f
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
@@ -0,0 +1,281 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "src/core/NEON/kernels/assembly/depthwise.hpp"
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+
+namespace arm_conv {
+namespace depthwise {
+
+template <typename T> struct DefaultTAccum { using Type = T; };  // Default TAccum: same type as the data
+template <> struct DefaultTAccum<int8_t> { using Type = int32_t; };  // int8_t data defaults to an int32_t TAccum
+template <> struct DefaultTAccum<uint8_t> { using Type = int32_t; };  // uint8_t data defaults to an int32_t TAccum
+
+template <typename T> struct DefaultOutputStage { using Type = Nothing; };  // Default OutputStage is the `Nothing` type
+template <> struct DefaultOutputStage<int8_t> { using Type = arm_gemm::Requantize32; };  // int8_t defaults to Requantize32
+template <> struct DefaultOutputStage<uint8_t> { using Type = arm_gemm::Requantize32; };  // uint8_t defaults to Requantize32
+
+class IDepthfirstStrategy
+{
+  public:
+  virtual ~IDepthfirstStrategy() = default;
+
+  virtual unsigned int get_input_rows() const = 0;  // Rows of input spanned by one tile
+  virtual unsigned int get_input_cols() const = 0;  // Columns of input spanned by one tile
+
+  virtual unsigned int get_output_rows() const = 0;  // Rows of output covered by one tile
+  virtual unsigned int get_output_cols() const = 0;  // Columns of output covered by one tile
+};
+
+
+template <typename T>
+struct TensorSpec
+{
+  T base;  // Handle to the first element; the drivers instantiate T as a pointer type
+  size_t ld_row, ld_col;  // Row / column strides, applied via pointer arithmetic on base
+
+  TensorSpec(T ptr, size_t ld_row, size_t ld_col)
+  : base(ptr), ld_row(ld_row), ld_col(ld_col) {}
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput>
+class DepthfirstDriver : public DepthwiseCommon<TInput, TWeight, TOutput>
+{
+  protected:
+  using Parent = DepthwiseCommon<TInput, TWeight, TOutput>;
+
+  // The strategy applied to solve the depthwise convolution; owned by this driver.
+  std::unique_ptr<const IDepthfirstStrategy> m_strat;
+
+  /* Compute the amount of working space (in bytes) required for a single thread. */
+  virtual size_t get_working_size_per_thread(unsigned int n_input_channels) const = 0;
+
+  /* Initialise the working space for a thread. */
+  virtual void initialise_working_space(void *, unsigned int n_input_channels) const = 0;
+
+  /* Compute a portion of the output tensor with padding. */
+  virtual void compute_tile_padded(
+    unsigned int output_i, unsigned int output_j,
+    unsigned int output_channel_start, unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space
+  ) const = 0;
+
+  /* Compute a row of n_tile_cols tiles with only top/bottom padding.
+   *
+   * The default implementation of this repeatedly calls into the padded tile
+   * variant.
+   */
+  virtual void compute_row_padded_tile_row(
+    const unsigned int output_i, unsigned int output_j, unsigned int n_tile_cols,
+    const unsigned int output_channel_start, const unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space
+  ) const
+  {
+    for (; n_tile_cols; n_tile_cols--, output_j += m_strat->get_output_cols())
+    {
+      this->compute_tile_padded(
+        output_i, output_j, output_channel_start, output_channel_end,
+        input, output, parameters, working_space
+      );
+    }
+  }
+
+  /* Compute an n_tile_rows x n_tile_cols grid of tiles with no padding.
+   *
+   * The default implementation of this repeatedly calls into the padded
+   * variant.
+   */
+  virtual void compute_tiles_unpadded(
+    unsigned int start_output_i, unsigned int start_output_j,
+    unsigned int n_tile_rows, unsigned int n_tile_cols,
+    unsigned int output_channel_start, unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space
+  ) const
+  {
+    for (unsigned int tile_i = 0; tile_i < n_tile_rows; tile_i++)
+    {
+      unsigned int row_start_output_j = start_output_j;
+      for (unsigned int tile_j = 0; tile_j < n_tile_cols; tile_j++)
+      {
+        this->compute_tile_padded(
+            start_output_i, row_start_output_j,
+            output_channel_start, output_channel_end,
+            input, output, parameters, working_space
+        );
+        row_start_output_j += m_strat->get_output_cols();
+      }
+      start_output_i += m_strat->get_output_rows();
+    }
+  }
+
+  void execute_internal(
+    unsigned int n_batches,
+    unsigned int input_height,
+    unsigned int input_width,
+    unsigned int n_input_channels,
+    const PaddingValues &padding,
+    const void *input,
+    size_t ld_input_col,
+    size_t ld_input_row,
+    size_t ld_input_batch,
+    const void *parameters,
+    unsigned int output_height,
+    unsigned int output_width,
+    void *output,
+    size_t ld_output_col,
+    size_t ld_output_row,
+    size_t ld_output_batch,
+    void *working_space,
+    unsigned int thread_id,
+    unsigned int n_threads
+  ) const override
+  {
+    // Get and initialise the working space for this thread.
+    void *thread_working_space =
+      static_cast<uint8_t *>(working_space) + thread_id * this->get_working_size_per_thread(n_input_channels);
+    this->initialise_working_space(thread_working_space, n_input_channels);
+
+    // Construct convenient representations of the input/output tensors.
+    TensorSpec<const TInput *> input_tensor(reinterpret_cast<const TInput *>(input), ld_input_row, ld_input_col);
+    TensorSpec<TOutput *> output_tensor(reinterpret_cast<TOutput *>(output), ld_output_row, ld_output_col);
+
+    const auto n_output_channels = n_input_channels * this->m_args.channel_multiplier;
+
+    for (unsigned int batch = 0; batch < n_batches; batch++)
+    {
+      // Iterate over rows of tiles of the output tensor; tile rows are striped across threads.
+      for (unsigned int start_output_i = thread_id * m_strat->get_output_rows();
+           start_output_i < output_height;
+           start_output_i += n_threads * m_strat->get_output_rows())
+      {
+        // Determine what padding (if any) is required on the top/bottom of
+        // this row of the convolution.
+        const auto end_output_i = start_output_i + m_strat->get_output_rows();
+        const bool pad_output_bottom = output_height < end_output_i;
+
+        const int start_input_i = start_output_i * this->m_args.stride_rows - padding.top;
+        const bool pad_input_top = start_input_i < 0;
+        const int end_input_i = start_input_i + m_strat->get_input_rows();
+        const bool pad_input_bottom = static_cast<int>(input_height) < end_input_i;
+        const bool pad_row = pad_input_top || pad_input_bottom || pad_output_bottom;
+
+        // Iterate over the columns of the output tensor; we attempt to grab as
+        // much as possible of the unpadded regions, so the loop structure is a
+        // bit odd.
+        unsigned int start_output_j = 0;
+        while (start_output_j < output_width)
+        {
+          const int start_in_j = start_output_j * this->m_args.stride_cols - padding.left;
+          const bool pad_input_left = start_in_j < 0;
+
+          // Determine if we can process a number of unpadded tiles in one go.
+          int n_unpadded_tiles = 0;
+          if (!pad_input_left)
+          {
+            // Determine the maximum number of tiles we could handle.
+            n_unpadded_tiles = (output_width - start_output_j) / m_strat->get_output_cols();
+
+            // Drop tiles which would extend past the right hand edge of the input or output
+            const int tile_stride = m_strat->get_output_cols() * this->m_args.stride_cols;
+            int end_output_j = start_output_j + n_unpadded_tiles * m_strat->get_output_cols();
+            int end_input_j = start_in_j + m_strat->get_input_cols() + (n_unpadded_tiles - 1)*tile_stride;
+
+            while (n_unpadded_tiles > 0 &&
+                   (static_cast<int>(output_width) < end_output_j ||
+                    static_cast<int>(input_width) < end_input_j))
+            {
+              n_unpadded_tiles--;
+              end_output_j -= m_strat->get_output_cols();
+              end_input_j -= tile_stride;
+            }
+          }
+
+          // Process unpadded tiles, if possible, otherwise process a padded tile.
+          if (n_unpadded_tiles)
+          {
+            if (!pad_row)
+            {
+              // Completely unpadded execution
+              this->compute_tiles_unpadded(
+                start_output_i, start_output_j,
+                1, n_unpadded_tiles,  // Compute a row of unpadded tiles
+                0, n_output_channels,  // Compute all channels
+                input_tensor, output_tensor, parameters, thread_working_space
+              );
+            }
+            else
+            {
+              // Top/bottom padding only
+              this->compute_row_padded_tile_row(
+                start_output_i, start_output_j, n_unpadded_tiles,
+                0, n_output_channels,  // Compute all channels
+                input_tensor, output_tensor, parameters, thread_working_space
+              );
+            }
+            start_output_j += n_unpadded_tiles * m_strat->get_output_cols();
+          }
+          else
+          {
+            this->compute_tile_padded(
+              start_output_i, start_output_j,
+              0, n_output_channels,  // Compute all channels
+              input_tensor, output_tensor, parameters, thread_working_space
+            );
+            start_output_j += m_strat->get_output_cols();
+          }
+        }
+      }
+
+      // Progress the pointers for the next batch.
+      input_tensor.base += ld_input_batch;
+      output_tensor.base += ld_output_batch;
+    }
+  }
+
+  public:
+  DepthfirstDriver(IDepthfirstStrategy *strategy, const DepthwiseArgs &args)
+  : Parent(args), m_strat(strategy)  // Takes ownership of strategy via the unique_ptr
+  {
+  }
+
+  size_t get_working_size(unsigned int n_threads, unsigned int n_input_channels) const override final
+  {
+    return n_threads * this->get_working_size_per_thread(n_input_channels);  // One per-thread region for each thread
+  }
+};
+
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
index 57fa111..6905076 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,7 +24,9 @@
 
 #pragma once
 
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+#include "src/core/NEON/kernels/arm_conv/addressing.hpp"
+#include "depthwise_strategies_common.hpp"
+#include "working_space.hpp"
 
 #ifdef CYCLE_PROFILING
 #include "profiler.hpp"
@@ -35,349 +37,547 @@
 namespace arm_conv {
 namespace depthwise {
 
-struct IDepthwiseDepthfirstStrategy
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum,
+          typename OutputStage>
+class DepthwiseDepthfirstStrategyCommon
+  : public DepthfirstStrategy<TInput, TWeight, TOutput, TAccum, OutputStage>
 {
-  virtual arm_gemm::VLType get_vl_type() const = 0;
+  protected:
+  unsigned int m_output_rows, m_output_cols;
+  unsigned int m_kernel_rows, m_kernel_cols;
+  unsigned int m_stride_rows, m_stride_cols;
 
-  virtual unsigned int get_input_rows() const = 0;
-  virtual unsigned int get_input_cols() const = 0;
+  public:
+  DepthwiseDepthfirstStrategyCommon(
+    unsigned int output_rows, unsigned int output_cols,
+    unsigned int kernel_rows, unsigned int kernel_cols,
+    unsigned int stride_rows=1, unsigned int stride_cols=1
+  ) : m_output_rows(output_rows), m_output_cols(output_cols),
+      m_kernel_rows(kernel_rows), m_kernel_cols(kernel_cols),
+      m_stride_rows(stride_rows), m_stride_cols(stride_cols)
+  {
+  }
 
-  virtual unsigned int get_output_rows() const = 0;
-  virtual unsigned int get_output_cols() const = 0;
+  DepthwiseDepthfirstStrategyCommon(unsigned int output_size, unsigned int kernel_size, unsigned int stride=1)
+  : DepthwiseDepthfirstStrategyCommon(output_size, output_size, kernel_size, kernel_size, stride, stride)
+  {
+  }
 
-  virtual unsigned int get_kernel_rows() const = 0;
-  virtual unsigned int get_kernel_cols() const = 0;
+  virtual ~DepthwiseDepthfirstStrategyCommon() {}
 
-  virtual unsigned int get_stride_rows() const = 0;
-  virtual unsigned int get_stride_cols() const = 0;
+  unsigned int get_output_rows() const override { return m_output_rows; }
+  unsigned int get_output_cols() const override { return m_output_cols; }
 
-  virtual void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const output_ptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const = 0;
+  unsigned int get_kernel_rows() const override { return m_kernel_rows; }
+  unsigned int get_kernel_cols() const override { return m_kernel_cols; }
 
-  virtual void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const = 0;
-
-  virtual ~IDepthwiseDepthfirstStrategy() {}
+  unsigned int get_stride_rows() const override { return m_stride_rows; }
+  unsigned int get_stride_cols() const override { return m_stride_cols; }
 };
 
-template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
-class DepthwiseDepthfirst : public DepthwiseCommon<TInput, TWeight, TOutput>
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirstStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
 {
-  const std::unique_ptr<IDepthwiseDepthfirstStrategy> m_strat;
+  using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
 
-  size_t sizeof_inptr_array(void) const
-  {
-    return sizeof(TInput *) * m_strat->get_input_rows() * m_strat->get_input_cols();
-  }
+  public:
+  using Parent::Parent;
 
-  size_t sizeof_input_buffer(unsigned int n_input_channels) const
-  {
-   return sizeof(TInput) * n_input_channels;
-  }
+  typedef void (*IndirectKernelType)(
+    const TInput *const *input_ptrs,
+    TOutput *const *output_ptrs,
+    const void *params,
+    unsigned int n_channels,
+    const TAccum activation_min,
+    const TAccum activation_max
+  );
+  virtual IndirectKernelType get_indirect_kernel(void) const = 0;
 
-  size_t sizeof_outptr_array(void) const
-  {
-    return sizeof(TInput *) * m_strat->get_output_rows() * m_strat->get_output_cols();
-  }
+  typedef void (*DirectKernelType)(
+    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
+    const TInput *inptr_base, int64_t ld_input_row, int64_t ld_input_col,
+    TOutput *outptr_base, int64_t ld_output_row, int64_t ld_output_col,
+    const void *params, unsigned int n_channels,
+    const TAccum activation_min,
+    const TAccum activation_max
+  );
+  virtual DirectKernelType get_direct_kernel(void) const = 0;
+};
 
-  size_t sizeof_output_buffer(unsigned int n_output_channels) const
+template <typename TInput, typename TWeight, typename TOutput>
+class DepthwiseDepthfirstStrategy<TInput, TWeight, TOutput, int32_t>
+: public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+  using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
+
+  protected:
+  interleaves::PackingArguments get_packing_args(void) const
   {
-    return sizeof(TOutput) * n_output_channels;
+    return interleaves::PackingArguments(
+      this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+      false, sizeof(int32_t),  // Don't pack the bias
+      this->get_vl_type(), sizeof(int32_t), this->get_accumulator_depth_vl(),
+      [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
+      { return this->get_kernel_packing_point(idx, x, y); }
+    );
   }
 
   public:
-  DepthwiseDepthfirst(
-    IDepthwiseDepthfirstStrategy *const strat,
-    const DepthwiseArgs &args
-  ) : DepthwiseCommon<TInput, TWeight, TOutput>(args), m_strat(strat)
+  using Parent::Parent;
+
+  typedef void (*KernelType)(
+    unsigned int,  //  n_channels,
+    const TInput *const *,  // inptrs
+    const TWeight *,  // weights
+    const int32_t *,  //  bias,
+    const arm_gemm::Requantize32 &,
+    const int32_t *, const int32_t *,  //  requant_muls and requant_shifts
+    TOutput *const *  // outptrs
+  );
+  virtual KernelType get_kernel() const = 0;
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleaves::get_storage_size_generic(get_packing_args(), args);
+  }
+
+  void pack_parameters(
+    const DepthwiseArgs &args, void *buffer,
+    const void *biases, const arm_gemm::Requantize32 &,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleaves::pack_parameters_generic(
+      get_packing_args(), args, buffer, biases, weights, ld_weight_col, ld_weight_row);
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+class DepthwiseDepthfirstCommon : public DepthfirstDriver<TInput, TWeight, TOutput>
+{
+  using StratType = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
+  OutputStage m_os;  // Copy of the output stage supplied at construction
+
+  protected:
+  inline OutputStage &get_output_stage(void) { return m_os; }
+  inline const OutputStage &get_output_stage(void) const { return m_os; }
+
+  public:
+  DepthwiseDepthfirstCommon(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os)
+  : DepthfirstDriver<TInput, TWeight, TOutput>(strat, args), m_os(os)  // strat is handed to the driver's unique_ptr
+  {
+  }
+
+  DepthwiseDepthfirstCommon(DepthwiseDepthfirstCommon &) = delete;  // Not copy-constructible
+  DepthwiseDepthfirstCommon &operator=(DepthwiseDepthfirstCommon &) = delete;  // Not copy-assignable
+
+  size_t get_storage_size(void) const override
+  {
+    return reinterpret_cast<const StratType *>(this->m_strat.get())->  // This class's constructor stored a StratType *
+      get_storage_size(this->m_args);
+  }
+
+  void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
+  {
+    reinterpret_cast<const StratType *>(this->m_strat.get())->  // Forward to the strategy, passing the stored output stage
+      pack_parameters(this->m_args, buffer, biases, m_os, weights, ld_weight_col, ld_weight_row);
+  }
+};
+
+namespace depthwise_depthfirst {
+
+/* Workspace Element for an array of input pointers as consumed by the
+ * specialised depthwise kernels.
+ */
+template <typename T>
+class InputArrayElement
+{
+  public:
+  struct Workspace
+  {
+    const T **inptr_array;  // Set by initialise() to the start of this element's buffer
+  };
+
+  template <class OutputStage>
+  static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    return sizeof(T **) * args.strategy->get_input_rows() * args.strategy->get_input_cols();  // One pointer per input row/column pair
+  }
+
+  template <class WorkspaceType, class OutputStage>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    ws->inptr_array = reinterpret_cast<const T**>(buffer);
+    return reinterpret_cast<char *>(buffer) + get_element_size(args);  // Address just past this element's storage
+  }
+};
+
+template <typename TAccum, typename OutputStage, bool IsDot=false>
+struct WorkspaceFinalElement
+{
+  using Element = ActivationsElement<TAccum, OutputStage>;  // General case: ActivationsElement
+};
+
+template <>
+struct WorkspaceFinalElement<int32_t, arm_gemm::Requantize32, false>
+{
+  using Element = RequantizationParametersElement;  // int32_t + Requantize32 with IsDot=false
+};
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct Invoke
+{
+  constexpr static bool supports_direct_kernel = true;  // direct() below calls the strategy's direct kernel
+
+  template <typename Strat, typename Workspace>
+  static inline void indirect(const Strat *strat, const Workspace *ws, const OutputStage &, const void *params, const TAccum *, unsigned int n_channels)  // Output stage and bias pointer are unused here
+  {
+    strat->get_indirect_kernel()(
+      ws->inptr_array,
+      ws->outptr_array,
+      params, n_channels,
+      ws->activation_min, ws->activation_max
+    );
+  }
+
+  template <typename Strat, typename Workspace>
+  static void direct(
+    const Strat *strat, const Workspace *ws, const OutputStage &,
+    unsigned int n_tile_rows, unsigned int n_tile_cols,
+    const TInput *inptr, size_t ld_in_row, size_t ld_in_col,
+    TOutput *outptr, size_t ld_out_row, size_t ld_out_col,
+    const void *params, unsigned int n_channels
+  )
+  {
+    strat->get_direct_kernel()(
+      n_tile_rows, n_tile_cols,
+      inptr, ld_in_row, ld_in_col,
+      outptr, ld_out_row, ld_out_col,
+      params, n_channels, ws->activation_min, ws->activation_max  // Activation bounds come from the workspace
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+struct Invoke<TInput, TWeight, TOutput, TAccum, arm_gemm::Requantize32>
+{
+  constexpr static bool supports_direct_kernel = false;  // direct() below is a no-op stub
+
+  template <typename Strat, typename Workspace>
+  static inline void indirect(const Strat *strat, const Workspace *ws, const arm_gemm::Requantize32 &qp, const void *params, const TAccum *, unsigned int n_channels)  // Bias pointer argument is unused; ws->bias is passed instead
+  {
+    strat->get_kernel()(
+      n_channels, ws->inptr_array,
+      reinterpret_cast<const TWeight *>(params), ws->bias,
+      qp, ws->requant_muls, ws->requant_shifts,
+      ws->outptr_array
+    );
+  }
+
+  template <typename Strat, typename Workspace>
+  static inline void direct(
+    const Strat *, const Workspace *, const arm_gemm::Requantize32 &,
+    unsigned int, unsigned int,  // n_tile_rows, n_tile_cols
+    const TInput *, size_t, size_t,  // Input pointer, row stride, column stride
+    TOutput *, size_t, size_t,  // Output pointer, row stride, column stride
+    const void *, unsigned int  // Parameters, number of channels
+  )
+  {
+    // Do nothing - this should never be reached because entry to it is guarded
+    // by an `if` on a `constexpr static bool`.
+  }
+};
+
+namespace
+{
+
+template <typename OutputStage>
+inline void stash_bias(OutputStage &, const void *) {}  // Generic output stage: nothing to record
+
+template <>
+inline void stash_bias(arm_gemm::Requantize32 &qp, const void *bias) __attribute__ ((unused));  // Declared first so the specialisation can carry the unused attribute
+
+template <>
+inline void stash_bias(arm_gemm::Requantize32 &qp, const void *bias)
+{
+  qp.bias = reinterpret_cast<const int32_t *>(bias);  // Record the bias pointer in the Requantize32 parameters
+}
+
+}
+
+}  // namespace depthwise_depthfirst
+
+template <typename TInput,
+          typename TWeight=TInput,
+          typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TInput>::Type,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirst
+: public DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+  using StratType = DepthwiseDepthfirstStrategy<TInput, TWeight, TOutput, TAccum>;
+  using Parent = DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
+  using WorkspaceManager = Workspace<
+    OutputArrayElement<TOutput>,
+    depthwise_depthfirst::InputArrayElement<TInput>,
+    InputBufferElement<TInput>,
+    typename depthwise_depthfirst::WorkspaceFinalElement<TAccum, OutputStage>::Element
+  >;
+  using WorkingSpace = typename WorkspaceManager::WorkspaceType;
+
+  // We keep a pointer to the bias; the output stage is stored by the parent class
+  const TAccum *m_bias;
+
+  public:
+  DepthwiseDepthfirst(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os = {})
+  : Parent(strat, args, os), m_bias(nullptr)
   {
   }
 
   DepthwiseDepthfirst(DepthwiseDepthfirst &) = delete;
   DepthwiseDepthfirst &operator=(DepthwiseDepthfirst &) = delete;
 
-  size_t get_storage_size(void) const override
+  void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
   {
-    // TODO What if we insert extra padding? Biases are a different size to the inputs, ...
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(m_strat->get_vl_type());
-    const auto rounded_channels = arm_gemm::roundup(this->m_args.input_channels, vl);
-    return (1 + this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight);
+    reinterpret_cast<const StratType *>(this->m_strat.get())->pack_parameters(
+      this->m_args, buffer, biases, this->get_output_stage(),
+      weights, ld_weight_col, ld_weight_row
+    );
+    m_bias = reinterpret_cast<const TAccum *>(biases);
+    depthwise_depthfirst::stash_bias(this->get_output_stage(), biases);
   }
 
-  void pack_parameters(void *_buffer, const void *_biases, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override
+  size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
   {
-    // TODO What if the kernel needs a different packing function?
-
-    // Cast the pointers
-    uint8_t *buffer = static_cast<uint8_t *>(_buffer);
-    const TAccum *biases = static_cast<const TAccum *>(_biases);
-    const TWeight *const weights = static_cast<const TWeight *>(_weights);
-
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TAccum>(m_strat->get_vl_type());
-    ld_weight_col = (ld_weight_col == 0) ? this->m_args.input_channels : ld_weight_col;
-    ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row;
-
-    for (unsigned int n = 0; n < this->m_args.input_channels; n += vl)
-    {
-      const unsigned int todo = std::min(vl, this->m_args.input_channels - n);
-
-      // Copy across the correct amount of bias (or 0)
-      for (unsigned int i = 0; i < todo; i++)
-      {
-        reinterpret_cast<TAccum *>(buffer)[i] = (biases == nullptr) ? 0 : biases[n + i];
-      }
-      buffer += vl * sizeof(TAccum);
-
-      // Copy each of the weights in turn
-      auto weights_row = weights + n;
-      for (unsigned int i = 0; i < this->m_args.kernel_rows; i++)
-      {
-        auto weights_col = weights_row;
-
-        for (unsigned int j = 0; j < this->m_args.kernel_cols; j++)
-        {
-          for (unsigned int m = 0; m < todo; m++)
-          {
-            reinterpret_cast<TWeight *>(buffer)[m] = weights_col[m];
-          }
-          buffer += vl * sizeof(TWeight);
-
-          weights_col += ld_weight_col;
-        }
-
-        weights_row += ld_weight_row;
-      }
-    }
+    DepthwiseArgs args(this->m_args);
+    args.input_channels = n_input_channels;
+    return WorkspaceManager::get_sizeof_workspace(
+      WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())
+    );
   }
 
-  size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override
+  void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
   {
-    const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier;
-    return n_threads * (sizeof_inptr_array() + sizeof_outptr_array() +
-                        sizeof_output_buffer(n_output_channels) +
-                        sizeof_input_buffer(n_channels));
+    DepthwiseArgs args(this->m_args);
+    args.input_channels = n_input_channels;
+    WorkspaceManager::initialise(
+      buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())
+    );
   }
 
-  using DepthwiseCommon<TInput, TWeight, TOutput>::execute;
-  void execute(
-    const unsigned int batches,
-    const unsigned int input_height,
-    const unsigned int input_width,
-    const unsigned int input_channels,
-    const PaddingValues &padding,
-    const void *const _input,
-    const size_t ld_input_col,
-    const size_t ld_input_row,
-    const size_t ld_input_batch,
-    const void *const parameters,
-    const unsigned int output_height,
-    const unsigned int output_width,
-    void *const _output,
-    const size_t ld_output_col,
-    const size_t ld_output_row,
-    const size_t ld_output_batch,
-    void *const _working_space,
-    const unsigned int thread_id,
-    const unsigned int n_threads
+  protected:
+  void compute_tile_padded(  // Compute one output tile at (output_i, output_j) through the indirect kernel
+    unsigned int output_i, unsigned int output_j,
+    unsigned int output_channel_start, unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space_raw
   ) const override
   {
-#ifdef CYCLE_PROFILING
-    arm_gemm::profiler prof;
-#endif
+    // Reinterpret the raw working space as this class's WorkingSpace structure
+    auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
 
-    // Compute activation values
-    TAccum activation_min, activation_max;
-    std::tie(activation_min, activation_max) = get_default_activation_values<TAccum>();
+    // Compute the input pointer array; dividing by the channel multiplier maps the output channel to its input channel
+    const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier;
 
-    switch (this->m_args.activation.type)
+    const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;  // Signed: negative inside the top padding
+    const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);  // Number of padded rows above the tile
+    const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);  // First real input row, clamped to zero
+
+    const int ij = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;  // Signed: negative inside the left padding
+    const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);  // Number of padded columns left of the tile
+    const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);  // First real input column, clamped to zero
+
+    fill_pointer_array<const TInput>(
+      ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
+      input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
+      input.ld_row, input.ld_col,
+      ws->input_buffer,
+      input_pad_top, this->m_args.input_rows - input_i,  // Top padding, # valid rows
+      input_pad_left, this->m_args.input_cols - input_j  // Left padding, # valid columns
+    );
+
+    // Compute the output pointer array
+    fill_pointer_array(
+      ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+      output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+      output.ld_row, output.ld_col,
+      ws->output_buffer,
+      0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+      0, this->m_args.output_cols - output_j  // Left padding, # valid columns
+    );
+
+    // Execute the kernel on the pointer arrays for the requested channel range
+    depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>::indirect(
+      reinterpret_cast<const StratType *>(this->m_strat.get()),
+      ws, this->get_output_stage(), parameters, m_bias, output_channel_end - output_channel_start
+    );
+  }
+
+  void compute_row_padded_tile_row(  // Compute n_tile_cols consecutive tiles of one tile row that has top/bottom padding only
+    const unsigned int output_i, unsigned int output_j, unsigned int n_tile_cols,
+    const unsigned int output_channel_start, const unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space
+  ) const override
+  {
+    using Invoker = depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>;
+    auto ws = reinterpret_cast<WorkingSpace *>(working_space);
+    const auto strat = reinterpret_cast<const StratType *>(this->m_strat.get());
+    const auto os = this->get_output_stage();
+
+    // Compute top and bottom padding; hence fill in the initial pointer arrays.
+    const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier;
+    const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+    const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
+
+    const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
+    const auto input_j = output_j * this->m_args.stride_cols - this->m_args.padding.left;  // Unsigned: relies on there being no left padding here
+
+    const auto valid_input_rows = std::min(strat->get_input_rows(), this->m_args.input_rows - input_i);  // Rows readable from input_i, capped at the tile height (top padding NOT included)
+    const auto valid_output_rows = std::min(strat->get_output_rows(), this->m_args.output_rows - output_i);
+
+    const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * this->m_args.stride_cols;  // Input pointer step between adjacent tiles
+    const auto output_point_stride = output.ld_col * this->m_strat->get_output_cols();  // Output pointer step between adjacent tiles
+
+    fill_pointer_array<const TInput>(
+      ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
+      input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
+      input.ld_row, input.ld_col,
+      ws->input_buffer,
+      input_pad_top, this->m_args.input_rows - input_i,
+      0, this->m_args.input_cols - input_j  // No left padding
+    );
+
+    fill_pointer_array(
+      ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+      output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+      output.ld_row, output.ld_col,
+      ws->output_buffer,
+      0, this->m_args.output_rows - output_i,  // Top padding, # valid rows
+      0, this->m_args.output_cols - output_j  // Left padding, # valid columns
+    );
+
+    for (; n_tile_cols; n_tile_cols--)
     {
-      case arm_gemm::Activation::Type::BoundedReLU:
-        activation_max = static_cast<TAccum>(this->m_args.activation.param1);
-        // Fall through
-      case arm_gemm::Activation::Type::ReLU:
-        activation_min = static_cast<TAccum>(0);
-        break;
-      default:
-        break;
-    }
+      // Execute the kernel
+      Invoker::indirect(
+        strat, ws, os, parameters, m_bias, output_channel_end - output_channel_start
+      );
 
-    // Determine what portion of the work to do.
-    const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
-    const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
-    const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
-
-    // Cast input and output pointers into the right types
-    const TInput *const inptr = static_cast<const TInput *>(_input);
-    TOutput *const outptr = static_cast<TOutput *>(_output);
-
-    // Allocate portions of the working space
-    uint8_t *working_space = static_cast<uint8_t *>(_working_space) + get_working_size(thread_id, input_channels);
-
-    const void **const inptr_array = reinterpret_cast<const void **>(working_space);
-    working_space += sizeof_inptr_array();
-
-    void **const outptr_array = reinterpret_cast<void **>(working_space);
-    working_space += sizeof_outptr_array();
-
-    TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
-    working_space += sizeof_output_buffer(input_channels * this->m_args.channel_multiplier);
-
-    TInput *const input_buffer = reinterpret_cast<TInput *>(working_space);
-
-    // Initialise the input buffer
-    for (unsigned int c = 0; c < input_channels; c++)
-    {
-      input_buffer[c] = static_cast<TInput>(0);
-    }
-
-    // For each output tile, construct the requisite set of pointers and call
-    // into the kernel.
-    for (unsigned int batch = 0; batch < batches; batch++)
-    {
-      // Get batch pointers
-      const auto inptr_batch = inptr + batch * ld_input_batch;
-      const auto outptr_batch = outptr + batch * ld_output_batch;
-
-      for (int start_out_i = start_out_height;
-           start_out_i < end_out_height;
-           start_out_i += static_cast<int>(m_strat->get_output_rows()))
+      // Update all unpadded pointers: array rows [input_pad_top, min(tile rows, input_pad_top + valid rows)); padded rows are left alone
       {
-        const int end_out_i = start_out_i + m_strat->get_output_rows();
-        const int start_in_i = start_out_i * m_strat->get_stride_rows() - padding.top;
-        const int end_in_i = start_in_i + m_strat->get_input_rows();
+        auto ptr = ws->inptr_array + strat->get_input_cols() * input_pad_top;  // Skip the rows of top padding
+        for (auto n = input_pad_top; n < std::min(strat->get_input_rows(), input_pad_top + valid_input_rows); n++)
+        {
+          for (auto m = 0u; m < strat->get_input_cols(); m++)
+          {
+            *(ptr++) += input_point_stride;
+          }
+        }
+      }
+      {
+        auto ptr = ws->outptr_array;
+        for (auto n = 0u; n < valid_output_rows * strat->get_output_cols(); n++)
+        {
+          *(ptr++) += output_point_stride;
+        }
+      }
+    }
+  }
 
-        // Compute top/bottom padding
-        const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
-        const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
-        const unsigned int valid_output_rows = std::min(
-          end_out_i - start_out_i,
-          static_cast<int>(output_height) - start_out_i
+  void compute_tiles_unpadded(  // Compute an n_tile_rows x n_tile_cols grid of tiles, none of which touch padding
+    unsigned int output_i, const unsigned int output_j,
+    unsigned int n_tile_rows, unsigned int n_tile_cols,
+    unsigned int output_channel_start, unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space_raw
+  ) const override
+  {
+    using Invoker = depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>;
+    auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
+    const auto strat = reinterpret_cast<const StratType *>(this->m_strat.get());
+    const auto os = this->get_output_stage();
+
+    if (Invoker::supports_direct_kernel)
+    {
+      // If the direct kernel is supported, then use it.
+      // Compute the base pointers we'll use in the tile.
+      auto outptr = output.base + output_channel_start + output_i * output.ld_row + output_j * output.ld_col;
+      const int start_input_i = output_i * this->m_args.stride_rows - this->m_args.padding.top;
+      const int start_input_j = output_j * this->m_args.stride_cols - this->m_args.padding.left;
+      auto inptr = input.base + output_channel_start + start_input_i * input.ld_row + start_input_j * input.ld_col;  // NOTE(review): input is offset by output_channel_start - presumably the direct path implies channel_multiplier == 1; confirm
+
+      // Execute the kernel
+      Invoker::direct(
+        strat, ws, os,
+        n_tile_rows, n_tile_cols,
+        inptr, input.ld_row, input.ld_col,
+        outptr, output.ld_row, output.ld_col,
+        parameters, output_channel_end - output_channel_start
+      );
+    }
+    else
+    {
+      // Otherwise, we repeatedly call the padded kernel but use our knowledge
+      // of the tensor structure to avoid recomputing the pointer array.
+      const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier;
+
+      const auto n_input_pointers = this->m_strat->get_input_rows() * this->m_strat->get_input_cols();
+      const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * this->m_args.stride_cols;  // Input pointer step between adjacent tiles
+      const auto n_output_pointers = this->m_strat->get_output_rows() * this->m_strat->get_output_cols();
+      const auto output_point_stride = output.ld_col * this->m_strat->get_output_cols();  // Output pointer step between adjacent tiles
+
+      // For each tile row, initialise the input and output pointer arrays. For
+      // each subsequent tile we simply update the pointers.
+      for (unsigned int tile_i = 0; tile_i < n_tile_rows; tile_i++)
+      {
+        const int input_i = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+        const int input_j = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+
+        fill_pointer_array<const TInput>(
+          ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
+          input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
+          input.ld_row, input.ld_col,
+          ws->input_buffer,
+          0, this->m_args.input_rows,  // No top padding
+          0, this->m_args.input_cols  // No left padding
         );
 
-        // Fill the input pointer array with padding values
-        for (auto index = 0u; index < m_strat->get_input_rows() * m_strat->get_input_cols(); index++)
+        // Compute the output pointer array
+        fill_pointer_array(
+          ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+          output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+          output.ld_row, output.ld_col,
+          ws->output_buffer,
+          0, this->m_args.output_rows,
+          0, this->m_args.output_cols
+        );
+
+        for (unsigned int tile_j = 0; tile_j < n_tile_cols; tile_j++)
         {
-          inptr_array[index] = input_buffer;
-        }
-
-        for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
-        {
-          const int start_in_j = start_out_j * m_strat->get_stride_cols() - this->m_args.padding.left;
-          int pad_left = std::min(0, start_in_j);
-
-          // Compute how many output tiles we can compute with the direct kernel.
-          int n_direct_tiles = 0;
-          if (!pad_top  && !pad_bottom && !pad_left)
-          {
-            // Determine the maximum number of tiles we could handle.
-            n_direct_tiles = (output_width - start_out_j) / m_strat->get_output_cols();
-
-            // Continue to reduce this number as required to avoid reading
-            // padding on the right edge.
-            int end_in_j = start_in_j + n_direct_tiles * m_strat->get_input_cols();
-            int pad_right = std::max(0, end_in_j - static_cast<int>(input_width));
-
-            while (pad_right && n_direct_tiles)
-            {
-              n_direct_tiles--;
-              end_in_j -= m_strat->get_input_cols();
-              pad_right = std::max(0, end_in_j - static_cast<int>(input_width));
-            }
-          }
-
-          // Use the unpadded kernel if we can, otherwise use the padded one.
-          if (n_direct_tiles)
-          {
-            auto inptr = inptr_batch + start_in_i*ld_input_row + start_in_j*ld_input_col;
-            auto outptr = outptr_batch + start_out_i*ld_output_row + start_out_j*ld_output_col;
-            start_out_j += n_direct_tiles*m_strat->get_output_cols();
-
-#ifdef CYCLE_PROFILING
-            auto p = prof.ScopedProfiler(PROFILE_KERNEL, 0);
-#endif
-            m_strat->direct_kernel(1, n_direct_tiles,
-                                   inptr, ld_input_row, ld_input_col,
-                                   outptr, ld_output_row, ld_output_col,
-                                   parameters, this->m_args.input_channels,
-                                   &activation_min, &activation_max);
-            continue;
-          }
-
-          const int end_out_j = start_out_j + m_strat->get_output_cols();
-          const int end_in_j = start_in_j + m_strat->get_input_cols();
-
-          const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
-          const unsigned int valid_output_cols = std::min(
-            end_out_j - start_out_j,
-            static_cast<int>(output_width) - start_out_j
+          // Invoke the indirect kernel for this tile
+          Invoker::indirect(
+            strat, ws, os, parameters, m_bias, output_channel_end - output_channel_start
           );
-          pad_left *= -1;
-          // Construct the input pointer array - fill the array with pointers to
-          // the input buffer and then fill in the required values.
-          for (auto i = pad_top; i < m_strat->get_input_rows() - pad_bottom; i++)
-          {
-            // Can skip over the left padding because we will have either the
-            // same or less than the previous tile.
-            unsigned int j = pad_left;
-            const TInput *colptr = inptr_batch + (start_in_i + i) * ld_input_row + (start_in_j + j) * ld_input_col;
-            const void **ptrs = inptr_array + i * m_strat->get_input_cols() + j;
-            for (; j < m_strat->get_input_cols() - pad_right; j++)
-            {
-              *(ptrs++) = colptr;
-              colptr += ld_input_col;
-            }
-            for (; j < m_strat->get_input_cols(); j++)
-            {
-              *(ptrs++) = input_buffer;
-            }
-          }
 
-          // Construct the output pointer array.
-          void **outptr_pos = outptr_array;
-          for (auto i = 0u; i < valid_output_rows; i++)
+          // Progress the pointers to the next tile in this row (every entry is a real pointer here)
+          for (auto i = 0u; i < n_input_pointers; i++)
           {
-            unsigned int j = 0u;
-            TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
-            for (; j < valid_output_cols; j++)
-            {
-              *(outptr_pos++) = colptr;
-               colptr += ld_output_col;
-            }
-            for (; j < m_strat->get_output_cols(); j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
+            ws->inptr_array[i] += input_point_stride;
           }
-          for (auto i = valid_output_rows; i < m_strat->get_output_rows(); i++)
+          for (auto i = 0u; i < n_output_pointers; i++)
           {
-            for (auto j = 0u; j < m_strat->get_output_cols(); j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
+            ws->outptr_array[i] += output_point_stride;
           }
-
-          start_out_j += m_strat->get_output_cols();
-
-#ifdef CYCLE_PROFILING
-          // TODO Work number
-          auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(0));
-#endif
-          m_strat->indirect_kernel(inptr_array, outptr_array, parameters,
-                                   this->m_args.input_channels,
-                                   &activation_min, &activation_max);
         }
+
+        output_i += this->m_strat->get_output_rows();  // Move down to the next tile row
       }
     }
   }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
index f04f775..9f53f7c 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,355 +24,277 @@
 
 #pragma once
 
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
-
-#ifdef CYCLE_PROFILING
-#include "profiler.hpp"
-#endif
-
-#include <limits>
+#include "depthwise_depthfirst.hpp"
 
 namespace arm_conv {
 namespace depthwise {
 
-template <class Strategy, unsigned OutputRows, unsigned int OutputCols>
-class DepthwiseDepthfirstGenericBase :
-  public DepthwiseCommon<typename Strategy::input_type,
-                         typename Strategy::weight_type,
-                         typename Strategy::return_type>
+template <typename TInput, typename TOutput, typename TAccum>
+struct GenericDepthfirstKernelStrategyFunctionType  // Maps (TInput, TOutput, TAccum) to the std::function type of a generic depthfirst kernel
 {
-  protected:
+  using KernelType = std::function<void(const TInput *const *const, TOutput *const *const, const void *, const void *, const unsigned int, const unsigned int, const TAccum, const TAccum)>;
+};  // General case: the kernel signature ends with two TAccum values
 
-  using TInput = typename Strategy::input_type;
-  using TWeight = typename Strategy::weight_type;
-  using TOutput = typename Strategy::return_type;
-  using TAccum = typename Strategy::bias_type;
+template <typename TInput, typename TOutput>
+struct GenericDepthfirstKernelStrategyFunctionType<TInput, TOutput, int32_t>  // Specialisation for int32_t accumulators
+{
+  using KernelType = std::function<void(const TInput *const *const, TOutput *const *const, const void *, const arm_gemm::Requantize32 &, unsigned int, unsigned int)>;
+};  // The kernel receives an arm_gemm::Requantize32 and takes no trailing TAccum arguments
 
-  size_t sizeof_input_ptr_array(void) const
-  {
-    return sizeof(TInput *) * this->m_args.kernel_rows * this->m_args.kernel_cols * Strategy::n_output_points;
-  }
-
-  size_t sizeof_input_buffer(unsigned int n_channels) const
-  {
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(Strategy::vl_type);
-    const auto rounded_channels = arm_gemm::roundup(n_channels, vl);
-    return sizeof(TInput) * rounded_channels;
-  }
-
-  size_t sizeof_output_buffer(unsigned int n_channels) const
-  {
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(Strategy::vl_type);
-    const auto rounded_channels = arm_gemm::roundup(n_channels, vl);
-    return sizeof(TOutput) * rounded_channels;
-  }
-
-  unsigned int input_rows(void) const
-  {
-    return this->m_args.kernel_rows + (OutputRows - 1)*this->m_args.stride_rows;
-  }
-
-  unsigned int input_cols(void) const
-  {
-    return this->m_args.kernel_cols + (OutputCols - 1)*this->m_args.stride_cols;
-  }
-
-  void execute_tiles(
-    std::function<void(const TInput *const *, TOutput *const *)> tile_fn,
-    std::function<void(TInput *, unsigned int)> initialise_input_buffer,
-    const unsigned int batches,
-    const unsigned int input_height,
-    const unsigned int input_width,
-    const unsigned int input_channels,
-    const PaddingValues &padding,
-    const void *const _input,
-    const size_t ld_input_col,
-    const size_t ld_input_row,
-    const size_t ld_input_batch,
-    const unsigned int output_height,
-    const unsigned int output_width,
-    void *const _output,
-    const size_t ld_output_col,
-    const size_t ld_output_row,
-    const size_t ld_output_batch,
-    void *const _working_space,
-    const unsigned int thread_id,
-    const unsigned int n_threads
-  ) const
-  {
-    static_assert(OutputRows * OutputCols <= Strategy::n_output_points,
-                  "Too many output points for kernel.");
-
-    // Determine what portion of the work to do.
-    const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
-    const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
-    const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
-
-    // Cast input and output pointers into the right types
-    const TInput *const inptr = static_cast<const TInput *>(_input);
-    TOutput *const outptr = static_cast<TOutput *>(_output);
-
-    // Allocate portions of the working space
-    uint8_t *const working_space = static_cast<uint8_t *>(_working_space) + this->get_working_size(thread_id, input_channels);
-    const TInput **const inptr_array = reinterpret_cast<const TInput **>(working_space);
-    TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space + this->sizeof_input_ptr_array());
-    TInput *const input_buffer = reinterpret_cast<TInput *>(working_space + this->sizeof_input_ptr_array() + this->sizeof_output_buffer(input_channels * this->m_args.channel_multiplier));
-
-    // Create an array for the output pointers
-    TOutput * _outptr_array[Strategy::n_output_points];
-    TOutput **const outptr_array = _outptr_array;
-
-    // Initialise the input buffer
-    initialise_input_buffer(input_buffer, input_channels);
-
-    // For each output tile, construct the requisite set of pointers and call
-    // into the kernel.
-    for (unsigned int batch = 0; batch < batches; batch++)
-    {
-      // Get batch pointers
-      const auto inptr_batch = inptr + batch * ld_input_batch;
-      const auto outptr_batch = outptr + batch * ld_output_batch;
-
-      for (int start_out_i = start_out_height;
-           start_out_i < end_out_height;
-           start_out_i += static_cast<int>(OutputRows))
-      {
-        const int end_out_i = std::min(start_out_i + OutputRows,
-                                       output_height);
-
-        for (int start_out_j = 0;
-             start_out_j < static_cast<int>(output_width);
-             start_out_j += static_cast<int>(OutputCols))
-        {
-          const int end_out_j = std::min(start_out_j + OutputCols,
-                                         output_width);
-
-          // Fill the pointer arrays with pointers to the input/output buffers.
-          for (auto index = 0u;
-               index < (Strategy::n_output_points * this->m_args.kernel_rows * this->m_args.kernel_cols);
-               index++)
-          {
-            inptr_array[index] = input_buffer;
-          }
-          for (auto index = 0u; index < Strategy::n_output_points; index++)
-          {
-            outptr_array[index] = output_buffer;
-          }
-
-          // Construct the pointer arrays together. Note that the input pointer
-          // array is striped. Since the array has already been filled with
-          // pointers to the padding array we merely fill in the valid points
-          // as we get to them.
-          unsigned int output_index = 0;
-          auto outptr_row = outptr_batch + start_out_i * ld_output_row + start_out_j * ld_output_col;
-          for (auto out_i = start_out_i; out_i < end_out_i; out_i++)
-          {
-            auto outptr_col = outptr_row;
-
-            // Compute the padding for this row of tiles.
-            const int start_in_i = out_i * this->m_args.stride_rows - padding.top;
-            const int end_in_i = start_in_i + this->m_args.kernel_rows;
-            const auto pad_top = static_cast<unsigned int>(std::max<int>(0, 0 - start_in_i));
-            const auto pad_bottom = static_cast<unsigned int>(std::max<int>(0, end_in_i - input_height));
-            const unsigned int valid_rows = this->m_args.kernel_rows - pad_top - pad_bottom;
-
-            for (auto out_j = start_out_j; out_j < end_out_j; out_j++, output_index++)
-            {
-              // Compute the output pointer.
-              outptr_array[output_index] = outptr_col;
-              outptr_col += ld_output_col;
-
-              // Compute the padding for this tile.
-              const int start_in_j = out_j * this->m_args.stride_cols - padding.left;
-              const int end_in_j = start_in_j + this->m_args.kernel_cols;
-              const auto pad_left = static_cast<unsigned int>(std::max<int>(0, 0 - start_in_j));
-              const auto pad_right = static_cast<unsigned int>(std::max<int>(0, end_in_j - input_width));
-              const unsigned int valid_cols = this->m_args.kernel_cols - pad_left - pad_right;
-
-              // Hence compute the input pointers.
-              auto input_index = output_index + Strategy::n_output_points * (pad_top * this->m_args.kernel_cols + pad_left);
-              auto inptr_row = inptr_batch + (start_in_i + pad_top) * ld_input_row + (start_in_j + pad_left) * ld_input_col;
-              for (auto in_i = 0u; in_i < valid_rows; in_i++)
-              {
-                auto inptr_col = inptr_row;
-                auto input_index_col = input_index;
-
-                for (auto in_j = 0u; in_j < valid_cols; in_j++)
-                {
-                  inptr_array[input_index_col] = inptr_col;
-                  inptr_col += ld_input_col;
-                  input_index_col += Strategy::n_output_points;
-                }
-
-                inptr_row += ld_input_row;
-                input_index += Strategy::n_output_points * this->m_args.kernel_cols;
-              }
-            }
-
-            outptr_row += ld_output_row;
-          }
-
-          tile_fn(inptr_array, outptr_array);
-        }
-      }
-    }
-  }
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class GenericDepthfirstKernelStrategy  // Abstract description of a generic depthfirst kernel: three stored properties plus the kernel callable
+{
+  unsigned int m_n_output_points;  // Returned by get_n_output_points()
+  arm_gemm::VLType m_vl_type;  // Returned by get_vl_type()
+  unsigned int m_accumulator_depth_vl;  // Returned by get_accumulator_depth_vl(); the constructor defaults it to 1
 
   public:
-  DepthwiseDepthfirstGenericBase(const DepthwiseArgs &args) : DepthwiseCommon<TInput, TWeight, TOutput>(args)
+  GenericDepthfirstKernelStrategy(unsigned int n_output_points, arm_gemm::VLType vl_type, unsigned int accumulator_depth_vl=1)
+  : m_n_output_points(n_output_points), m_vl_type(vl_type), m_accumulator_depth_vl(accumulator_depth_vl)
   {
   }
 
-  DepthwiseDepthfirstGenericBase(DepthwiseDepthfirstGenericBase &) = delete;
-  DepthwiseDepthfirstGenericBase &operator=(DepthwiseDepthfirstGenericBase &) = delete;
+  virtual ~GenericDepthfirstKernelStrategy() = default;  // Virtual so derived strategies can be destroyed through a base pointer
 
-  size_t get_storage_size(void) const override
+  virtual arm_gemm::VLType get_vl_type() const { return m_vl_type; }
+  virtual unsigned int get_accumulator_depth_vl() const { return m_accumulator_depth_vl; }
+  virtual unsigned int get_n_output_points() const { return m_n_output_points; }
+
+  using KernelType = typename GenericDepthfirstKernelStrategyFunctionType<TInput, TOutput, TAccum>::KernelType;  // Signature depends on TAccum
+  virtual KernelType get_kernel(void) const = 0;  // Derived classes supply the kernel callable
+};
+
+template <typename TInput,
+          typename TWeight=TInput,
+          typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TInput>::Type,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class GenericDepthfirstStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+  protected:
+  using KernelStrategyType = GenericDepthfirstKernelStrategy<TInput, TWeight, TOutput, TAccum>;
+  std::unique_ptr<KernelStrategyType> m_strategy;  // Owned kernel strategy; VL queries and get_kernel() forward to it
+
+  public:
+  GenericDepthfirstStrategy(
+    KernelStrategyType *strat, unsigned int n_output_rows, unsigned int n_output_cols,
+    const DepthwiseArgs &args
+  )
+  : DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>(
+      n_output_rows, n_output_cols,
+      args.kernel_rows, args.kernel_cols,
+      args.stride_rows, args.stride_cols
+    ),
+    m_strategy(strat)  // Takes ownership of strat (stored in a std::unique_ptr)
   {
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TAccum>(Strategy::vl_type);
-    const auto rounded_channels = arm_gemm::roundup(this->m_args.input_channels, vl);
-    return (this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight);
   }
 
-  void pack_parameters(void *_buffer, const void *, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override
+  GenericDepthfirstStrategy(GenericDepthfirstStrategy &) = delete;
+  GenericDepthfirstStrategy &operator=(GenericDepthfirstStrategy &) = delete;  // Non-copyable; assignment is declared with the conventional reference return
+
+  arm_gemm::VLType get_vl_type(void) const override { return m_strategy->get_vl_type(); }
+  unsigned int get_accumulator_depth_vl(void) const override { return m_strategy->get_accumulator_depth_vl(); }
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
   {
-    // Cast the pointers
-    TWeight *buffer = static_cast<TWeight *>(_buffer);
-    const TWeight *const weights = static_cast<const TWeight *>(_weights);
-
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TAccum>(Strategy::vl_type);
-    ld_weight_col = (ld_weight_col == 0) ? this->m_args.input_channels : ld_weight_col;
-    ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row;
-
-    for (unsigned int n = 0; n < this->m_args.input_channels; n += vl)
-    {
-      const unsigned int todo = std::min(vl, this->m_args.input_channels - n);
-
-      // Copy each of the weights in turn
-      auto weights_row = weights + n;
-      for (unsigned int i = 0; i < this->m_args.kernel_rows; i++)
-      {
-        auto weights_col = weights_row;
-
-        for (unsigned int j = 0; j < this->m_args.kernel_cols; j++)
-        {
-          for (unsigned int m = 0; m < todo; m++)
-          {
-            buffer[m] = weights_col[m];
-          }
-          buffer += vl;
-
-          weights_col += ld_weight_col;
-        }
-
-        weights_row += ld_weight_row;
-      }
-    }
+    interleaves::PackingArguments packing_args(  // Must describe the same layout as the arguments built in pack_parameters()
+      this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+      false, sizeof(TAccum),  // Don't pack the bias
+      this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
+      [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
+      { return this->get_kernel_packing_point(idx, x, y); }
+    );
+    return interleaves::get_storage_size_generic(packing_args, args);
   }
 
-  size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override
+  void pack_parameters(
+    const DepthwiseArgs &args, void *buffer,
+    const void *biases, const OutputStage &,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
   {
-    const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier;
-    return n_threads * (sizeof_input_ptr_array() +
-                        sizeof_output_buffer(n_output_channels) +
-                        sizeof_input_buffer(n_channels));
+    interleaves::PackingArguments packing_args(  // Same arguments as get_storage_size(), so the buffer size and the packing agree
+      this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+      false, sizeof(TAccum),  // Don't pack the bias
+      this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
+      [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
+      { return this->get_kernel_packing_point(idx, x, y); }
+    );
+    interleaves::pack_parameters_generic(
+      packing_args, args, buffer, biases, weights, ld_weight_col, ld_weight_row);
+  }
+
+  const typename KernelStrategyType::KernelType get_kernel() const { return m_strategy->get_kernel(); }
+};
+
+// Use a templated function to marshal arguments when executing the kernel.
+template <typename OutputStage> struct DepthwiseDepthfirstGenericKernelCall;
+
+template <>
+struct DepthwiseDepthfirstGenericKernelCall<Nothing>  // Kernel call used when the output stage is Nothing
+{
+  template <typename StratType, typename WorkspaceType, typename TAccum>
+  static void execute(
+    const StratType *strat, const WorkspaceType *ws, const Nothing &,
+    const TAccum *bias, const void *params,
+    const unsigned int n_kernel_points, const unsigned int n_output_channels
+  )
+  {
+    strat->get_kernel()(  // Fetch the kernel callable from the strategy and invoke it immediately
+      ws->inptr_array,
+      ws->outptr_array,
+      params, bias,
+      n_kernel_points, n_output_channels,
+      ws->activation_min, ws->activation_max  // Activation bounds are read from the workspace
+    );
   }
 };
 
-template <class Strategy, unsigned OutputRows, unsigned int OutputCols>
-class DepthwiseDepthfirstGeneric : public DepthwiseDepthfirstGenericBase<Strategy, OutputRows, OutputCols>
+template <>
+struct DepthwiseDepthfirstGenericKernelCall<arm_gemm::Requantize32>  // Kernel call used when the output stage is arm_gemm::Requantize32
 {
-  using Parent = DepthwiseDepthfirstGenericBase<Strategy, OutputRows, OutputCols>;
-  using TInput = typename Parent::TInput;
-  using TWeight = typename Parent::TWeight;
-  using TAccum = typename Parent::TAccum;
-  using TOutput = typename Parent::TOutput;
+  template <typename StratType, typename WorkspaceType>
+  static void execute(
+    const StratType *strat, const WorkspaceType *ws, const arm_gemm::Requantize32 &qp,
+    const int32_t *, const void *params,  // The bias pointer is accepted for signature parity but not passed to the kernel
+    const unsigned int n_kernel_points, const unsigned int n_output_channels
+  )
+  {
+    strat->get_kernel()(  // Fetch the kernel callable from the strategy and invoke it immediately
+      ws->inptr_array,
+      ws->outptr_array,
+      params, qp,  // The requantisation parameters take the place of the bias argument
+      n_kernel_points, n_output_channels
+    );
+  }
+};
 
+
+/* Workspace Element for an array of input pointers as consumed by the
+ * "Generic" depthwise kernels.
+ */
+template <typename T>
+class GenericInputArrayElement
+{
+  public:
+  struct Workspace
+  {
+    const T **inptr_array;
+  };
+
+  template <class OutputStage>
+  static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    const auto kernel_points = args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+    return sizeof(T **) * args.strategy->get_input_rows() * args.strategy->get_input_cols() * kernel_points;
+  }
+
+  template <class WorkspaceType, class OutputStage>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    ws->inptr_array = reinterpret_cast<const T**>(buffer);
+    return reinterpret_cast<char *>(buffer) + get_element_size(args);
+  }
+};
+
+template <typename TInput, typename TWeight=TInput, typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TInput>::Type,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirstGeneric : public DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+  using StratType = GenericDepthfirstStrategy<TInput, TWeight, TOutput, TAccum, OutputStage>;
+  using Parent = DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
+  using WorkspaceManager = Workspace<
+    OutputArrayElement<TOutput>,
+    GenericInputArrayElement<TInput>,
+    InputBufferElement<TInput>,
+    ActivationsElement<TAccum, OutputStage>
+  >;
+  using WorkingSpace = typename WorkspaceManager::WorkspaceType;
   const TAccum *m_bias = nullptr;
 
   public:
-  DepthwiseDepthfirstGeneric(const DepthwiseArgs &args) : Parent(args)
+  DepthwiseDepthfirstGeneric(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os={})
+  : Parent(strat, args, os)
   {
   }
 
   DepthwiseDepthfirstGeneric(DepthwiseDepthfirstGeneric &) = delete;
   DepthwiseDepthfirstGeneric &operator=(DepthwiseDepthfirstGeneric &) = delete;
 
-  void pack_parameters(void *buffer, const void *bias, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
+  void pack_parameters(
+    void *buffer, const void *biases,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) override
   {
-    m_bias = static_cast<const TAccum *>(bias);
-    Parent::pack_parameters(buffer, bias, weights, ld_weight_col, ld_weight_row);
+    Parent::pack_parameters(buffer, biases, weights, ld_weight_col, ld_weight_row);
+    m_bias = reinterpret_cast<const TAccum *>(biases);  // Keep the caller's bias pointer (no copy is made)
+    depthwise_depthfirst::stash_bias(this->get_output_stage(), m_bias);
   }
 
-  using DepthwiseDepthfirstGenericBase<Strategy, OutputRows, OutputCols>::execute;
-  void execute(
-    const unsigned int batches,
-    const unsigned int input_height,
-    const unsigned int input_width,
-    const unsigned int input_channels,
-    const PaddingValues &padding,
-    const void *const _input,
-    const size_t ld_input_col,
-    const size_t ld_input_row,
-    const size_t ld_input_batch,
-    const void *const parameters,
-    const unsigned int output_height,
-    const unsigned int output_width,
-    void *const _output,
-    const size_t ld_output_col,
-    const size_t ld_output_row,
-    const size_t ld_output_batch,
-    void *const _working_space,
-    const unsigned int thread_id,
-    const unsigned int n_threads
+  size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
+  {
+    DepthwiseArgs args(this->m_args);
+    args.input_channels = n_input_channels;
+    return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage()));
+  }
+
+  void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
+  {
+    DepthwiseArgs args(this->m_args);
+    args.input_channels = n_input_channels;
+    return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage()));
+  }
+
+  protected:
+  void compute_tile_padded(
+    unsigned int output_i, unsigned int output_j,
+    unsigned int channel_start, unsigned int channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space_raw
   ) const override
   {
-    Strategy strat(this->m_args.cpu_info);
-#ifdef CYCLE_PROFILING
-    arm_gemm::profiler prof;
-#endif
+    // Get the working space
+    WorkingSpace *ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
 
-    // Compute activation values
-    TAccum activation_min, activation_max;
-    std::tie(activation_min, activation_max) = get_default_activation_values<TAccum>();
+    const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+    const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
+    const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
 
-    switch (this->m_args.activation.type)
-    {
-      case arm_gemm::Activation::Type::BoundedReLU:
-        activation_max = static_cast<TAccum>(this->m_args.activation.param1);
-        // Fall through
-      case arm_gemm::Activation::Type::ReLU:
-        activation_min = static_cast<TAccum>(0);
-        break;
-      default:
-        break;
-    }
+    const int ij = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+    const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);
+    const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);
 
-    // Create a function to initialise the input buffer
-    const auto initialise_input_buffer = [] (TInput *const buffer, const unsigned int n) {
-      std::memset(buffer, 0, n * sizeof(TInput));
-    };
+    fill_pointer_array_generic_kernel<const TInput>(
+      ws->inptr_array,
+      this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+      this->m_args.kernel_rows, this->m_args.kernel_cols,
+      this->m_args.stride_rows, this->m_args.stride_cols,
+      input.base + input_i*input.ld_row + input_j*input.ld_col + channel_start,
+      input.ld_row, input.ld_col,
+      ws->input_buffer,
+      input_pad_top, this->m_args.input_rows - input_i,
+      input_pad_left, this->m_args.input_cols - input_j
+    );
 
-    // Create a function to execute a tile of work
-    const auto tile_fn = [&] (const TInput *const *const inptrs, TOutput *const * const outptrs) {
-#ifdef CYCLE_PROFILING
-      auto p = prof.ScopedProfiler(
-        PROFILE_KERNEL,
-        (unsigned long) (OutputRows * OutputCols * this->m_args.kernel_rows* this->m_args.kernel_cols)
-      );
-#endif
-      strat.kernel(inptrs, outptrs, parameters, m_bias,
-                   this->m_args.kernel_rows * this->m_args.kernel_cols,
-                   this->m_args.input_channels, activation_min, activation_max);
-    };
+    // Compute the output pointer array
+    fill_pointer_array<TOutput>(
+      ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+      output.base + output_i*output.ld_row + output_j*output.ld_col + channel_start,
+      output.ld_row, output.ld_col,
+      ws->output_buffer,
+      0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+      0, this->m_args.output_cols - output_j  // Left padding, # valid columns
+    );
 
-    // Call into a parent utility function to do the actual work.
-    Parent::execute_tiles(
-      tile_fn, initialise_input_buffer,
-      batches, input_height, input_width, input_channels, padding,
-      _input, ld_input_col, ld_input_row, ld_input_batch,
-      output_height, output_width,
-      _output, ld_output_col, ld_output_row, ld_output_batch,
-      _working_space, thread_id, n_threads
+    // Execute the kernel
+    DepthwiseDepthfirstGenericKernelCall<OutputStage>::execute(
+      reinterpret_cast<const StratType *>(this->m_strat.get()), ws,
+      this->get_output_stage(), m_bias, parameters,
+      this->m_args.kernel_rows * this->m_args.kernel_cols,
+      channel_end - channel_start
     );
   }
 };
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
index 2862361..e58467b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,7 +24,8 @@
 
 #pragma once
 
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+#include "depthwise_depthfirst.hpp"
+#include "interleaves/generic_quantized_dot_product.hpp"
 
 #ifdef CYCLE_PROFILING
 #include "profiler.hpp"
@@ -35,492 +36,559 @@
 namespace arm_conv {
 namespace depthwise {
 
-namespace common
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class DepthfirstMultiplierStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, Nothing>
 {
-  template <typename strategy, typename F>
-  void depthwise_multiplier_execute(
-    const F execute_tile,
-    typename strategy::input_type pad_value,
-    const DepthwiseArgs &args,
-    const unsigned int batches,
-    const unsigned int input_height,
-    const unsigned int input_width,
-    const unsigned int input_channels,
-    const PaddingValues &padding,
-    const void *const _input,
-    const size_t ld_input_col,
-    const size_t ld_input_row,
-    const size_t ld_input_batch,
-    const void *const parameters,
-    const size_t param_stride,
-    const unsigned int output_height,
-    const unsigned int output_width,
-    void *const _output,
-    const size_t ld_output_col,
-    const size_t ld_output_row,
-    const size_t ld_output_batch,
-    void *const _working_space,
-    const unsigned int thread_id,
-    const unsigned int n_threads
-  )
+  using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, Nothing>;
+
+  protected:
+  virtual interleaves::PackingArguments get_packing_args(const DepthwiseArgs &args) const
   {
-    using TInput = typename strategy::input_type;
-    using TOutput = typename strategy::return_type;
-
-    // Determine what portion of the work to do.
-    const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
-    const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
-    const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
-
-    // Cast input and output pointers into the right types
-    const TInput *const inptr = static_cast<const TInput *>(_input);
-    TOutput *const outptr = static_cast<TOutput *>(_output);
-
-    // To simplify the kernel, we process padded or non-NCHW-ordered input into
-    // a form which can be consumed by the kernel. This data is stored here and
-    // passed into the kernel as an array of N pointers (one per row of the
-    // input).
-    TInput rearranged_input[strategy::input_rows][strategy::input_col_quads*(16 / sizeof(TInput))];
-    const TInput *inptrs[strategy::input_rows];
-
-    // Create an array for the output pointers
-    TOutput * _outptr_array[strategy::output_rows * strategy::output_cols];
-    TOutput **const outptr_array = _outptr_array;
-
-    // Allocate portions of the working space
-    uint8_t *const working_space = static_cast<uint8_t *>(_working_space);
-    TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
-
-    // For each output tile, construct the requisite set of pointers and call
-    // into the kernel.
-    for (unsigned int batch = 0; batch < batches; batch++)
-    {
-      // Get batch pointers
-      const auto inptr_batch = inptr + batch * ld_input_batch;
-      const auto outptr_batch = outptr + batch * ld_output_batch;
-
-      for (int start_out_i = start_out_height;
-           start_out_i < end_out_height;
-           start_out_i += static_cast<int>(strategy::output_rows))
+    return interleaves::PackingArguments(
+      args.kernel_rows, args.kernel_cols, sizeof(TWeight),
+      true, sizeof(TAccum),
+      this->get_vl_type(),
+      sizeof(TAccum), 1,
+      [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
       {
-        const int end_out_i = start_out_i + strategy::output_rows;
-        const int start_in_i = start_out_i * strategy::stride_rows - padding.top;
-        const int end_in_i = start_in_i + strategy::input_rows;
-
-        // Compute top/bottom padding
-        const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
-        const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
-        const unsigned int valid_output_rows = std::min(
-          end_out_i - start_out_i,
-          static_cast<int>(output_height) - start_out_i
-        );
-
-        for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
+        if (pos < args.kernel_rows * args.kernel_cols)
         {
-          const int start_in_j = start_out_j * strategy::stride_cols - args.padding.left;
-          const int pad_left = -std::min(0, start_in_j);
-
-          const int end_out_j = start_out_j + strategy::output_cols;
-          const int end_in_j = start_in_j + strategy::input_cols;
-
-          const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
-          const unsigned int valid_output_cols = std::min(
-            end_out_j - start_out_j,
-            static_cast<int>(output_width) - start_out_j
-          );
-
-          // Construct the output pointer array.
-          TOutput **outptr_pos = outptr_array;
-          for (auto i = 0u; i < valid_output_rows; i++)
-          {
-            unsigned int j = 0u;
-            TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
-            for (; j < valid_output_cols; j++)
-            {
-              *(outptr_pos++) = colptr;
-               colptr += ld_output_col;
-            }
-            for (; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-          for (auto i = valid_output_rows; i < strategy::output_rows; i++)
-          {
-            for (auto j = 0u; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-
-          start_out_j += strategy::output_cols;
-
-          const uint8_t *params = static_cast<const uint8_t *>(parameters);
-
-          // Loop over the input channels
-          for (unsigned int in_c = 0; in_c < input_channels; in_c++)
-          {
-            // Construct the input array - first fill with padding values and
-            // then fill in correct values.
-            for (unsigned int i = 0; i < strategy::input_rows; i++)
-            {
-              for (unsigned int j = 0;
-                   j < (16 / sizeof(TInput)) * strategy::input_col_quads; j++)
-              {
-                rearranged_input[i][j] = pad_value;
-              }
-              inptrs[i] = rearranged_input[i];
-            }
-
-            auto inptr_row = inptr_batch + in_c +
-                             (start_in_i + pad_top) * ld_input_row +
-                             (start_in_j + pad_left) * ld_input_col;
-            if (ld_input_col == 1 && !pad_left &&
-                start_in_j + (16 / sizeof(TInput)) * strategy::input_col_quads < input_width)
-            {
-              // The input tensor is already in NCHW format, and we're reading
-              // an unpadded section of it - allow the kernel to read it
-              // directly.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                inptrs[i] = inptr_row;
-                inptr_row += ld_input_row;
-              }
-            }
-            else
-            {
-              // Either the input tensor isn't in NCHW format, or we're reading
-              // a padded section. Copy the relevant portion of the input here
-              // and allow the kernel to read this.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                auto inptr_col = inptr_row;
-                for (unsigned int j = pad_left; j < strategy::input_cols - pad_right; j++)
-                {
-                  rearranged_input[i][j] = *inptr_col;
-                  inptr_col += ld_input_col;
-                }
-                inptr_row += ld_input_row;
-              }
-            }
-
-            execute_tile(inptrs, outptr_array, params);
-
-            // Progress the output pointers
-            TOutput **outptr_pos = outptr_array;
-            for (auto i = 0u; i < strategy::output_rows * strategy::output_cols; i++)
-            {
-              outptr_pos[i] += args.channel_multiplier;
-            }
-
-            // Progress the pointer into the parameters
-            params += param_stride;
-          }
+          y = pos % args.kernel_cols;
+          x = pos / args.kernel_cols;
+          return true;
         }
+        return false;
       }
-    }
-  }
-}
-
-template <class strategy>
-class DepthwiseDepthfirstWithMultiplier :
-  public DepthwiseCommon<typename strategy::input_type,
-                         typename strategy::weight_type,
-                         typename strategy::return_type>
-{
-  using TInput = typename strategy::input_type;
-  using TWeight = typename strategy::weight_type;
-  using TOutput = typename strategy::return_type;
-  using TAccum = typename strategy::bias_type;
-
-  size_t sizeof_output_buffer(unsigned int n_channels) const
-  {
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(strategy::vl_type);
-    const auto rounded_channels = arm_gemm::roundup(n_channels, vl);
-    return sizeof(TOutput) * rounded_channels;
+    );
   }
 
   public:
-  DepthwiseDepthfirstWithMultiplier(const DepthwiseArgs &args) : DepthwiseCommon<TInput, TWeight, TOutput>(args)
+  using Parent::Parent;
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleaves::get_storage_size_generic(this->get_packing_args(args), args);
+  }
+
+  void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const Nothing &, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+  {
+    interleaves::pack_parameters_generic(
+      this->get_packing_args(args), args,
+      buffer, biases, weights, ld_weight_col, ld_weight_row
+    );
+  }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const void *,  // Ravelled bias, weights, and quantization parameters
+    unsigned int,  // # output channels
+    TAccum, TAccum  // Min and max activation clamps
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput>
+class DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t> : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+  using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
+
+  public:
+  using Parent::Parent;
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleaves::quantized::get_storage_size(args, this->get_vl_type(), this->get_accumulator_depth_vl());
+  }
+
+  void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+  {
+    interleaves::quantized::pack_parameters<TWeight>(
+      buffer, reinterpret_cast<const int32_t *>(biases),
+      reinterpret_cast<const TWeight *>(weights), ld_weight_col, ld_weight_row,
+      args, qp, this->get_vl_type(), this->get_accumulator_depth_vl()
+    );
+  }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const void *,  // Ravelled bias, weights, and quantization parameters
+    unsigned int,  // # output channels
+    const arm_gemm::Requantize32 &
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class GenericDepthfirstMultiplierKernelStrategy
+{
+  const arm_gemm::VLType m_vl_type;
+  const unsigned int m_output_rows, m_output_cols;
+
+  public:
+  GenericDepthfirstMultiplierKernelStrategy(unsigned int output_rows, unsigned int output_cols, arm_gemm::VLType vl_type)
+  : m_vl_type(vl_type), m_output_rows(output_rows), m_output_cols(output_cols)
   {
   }
 
-  DepthwiseDepthfirstWithMultiplier(DepthwiseDepthfirstWithMultiplier &) = delete;
-  DepthwiseDepthfirstWithMultiplier &operator=(DepthwiseDepthfirstWithMultiplier &) = delete;
+  virtual ~GenericDepthfirstMultiplierKernelStrategy() = default;
+
+  arm_gemm::VLType get_vl_type(void) const { return m_vl_type; }
+  unsigned int get_output_rows(void) const { return m_output_rows; }
+  unsigned int get_output_cols(void) const { return m_output_cols; }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const TWeight *,  // Ravelled weight parameters
+    const TAccum *,  // Bias
+    unsigned int, unsigned int,  // Number of kernel points, number of output channels
+    TAccum, TAccum  // Activation minimum and maximum
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+class GenericDepthfirstMultiplierKernelStrategy<TInput, TWeight, TOutput, int32_t>
+{
+  const arm_gemm::VLType m_vl_type;
+  const unsigned int m_output_rows, m_output_cols;
+
+  public:
+  GenericDepthfirstMultiplierKernelStrategy(unsigned int output_rows, unsigned int output_cols, arm_gemm::VLType vl_type)
+  : m_vl_type(vl_type), m_output_rows(output_rows), m_output_cols(output_cols)
+  {
+  }
+
+  virtual ~GenericDepthfirstMultiplierKernelStrategy() = default;
+
+  arm_gemm::VLType get_vl_type(void) const { return m_vl_type; }
+  unsigned int get_output_rows(void) const { return m_output_rows; }
+  unsigned int get_output_cols(void) const { return m_output_cols; }
+
+  using KernelType = std::function<void(
+    const TInput *const *,  // Input pointers
+    TOutput *const *,  // Output pointers
+    const TWeight *,  // Ravelled weight parameters
+    const int32_t *,  // Bias
+    unsigned int, unsigned int,  // Number of kernel points, number of output channels
+    const int32_t *, const int32_t *, const int32_t *,  // Per-channel left-shifts, multipliers, right-shifts (need to account for start channel)
+    const arm_gemm::Requantize32 &
+  )>;
+  virtual KernelType get_kernel(void) const = 0;
+};
+
+template <typename TInput,
+          typename TWeight=TInput,
+          typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TInput>::Type,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class GenericDepthfirstMultiplierStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+  using KernelStrategyType = GenericDepthfirstMultiplierKernelStrategy<TInput, TWeight, TOutput, TAccum>;
+  std::unique_ptr<KernelStrategyType> m_kern;
+
+  protected:
+  virtual interleaves::PackingArguments get_packing_args(const DepthwiseArgs &args) const
+  {
+    return interleaves::PackingArguments(
+      args.kernel_rows, args.kernel_cols, sizeof(TWeight),
+      false, sizeof(TAccum),
+      this->get_vl_type(),
+      sizeof(TAccum), 1,
+      [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
+      {
+        if (pos < args.kernel_rows * args.kernel_cols)
+        {
+          y = pos % args.kernel_cols;
+          x = pos / args.kernel_cols;
+          return true;
+        }
+        return false;
+      }
+    );
+  }
+
+  public:
+  GenericDepthfirstMultiplierStrategy(KernelStrategyType *kern, const DepthwiseArgs &args)
+  : DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>(
+      kern->get_output_rows(), kern->get_output_cols(),
+      args.kernel_rows, args.kernel_cols,
+      args.stride_rows, args.stride_cols
+    ),
+    m_kern(kern)
+  {
+  };
+
+  arm_gemm::VLType get_vl_type(void) const override { return m_kern->get_vl_type(); }
+  const typename KernelStrategyType::KernelType get_kernel(void) const { return m_kern->get_kernel(); }
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleaves::get_storage_size_generic(this->get_packing_args(args), args);
+  }
+
+  void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const OutputStage &, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+  {
+    interleaves::pack_parameters_generic(
+      this->get_packing_args(args), args,
+      buffer, biases, weights, ld_weight_col, ld_weight_row
+    );
+  }
+};
+
+// Specialise elements of the wrapper based on the type of kernel.
+namespace depthfirst_multiplier {
+
+/* Working space element which contains a pointer for each row of input, a row
+ * of padding, and a space which can be used to construct an NCHW-ordered patch
+ * of input.
+ */
+template <typename T, bool IsGeneric=false, typename OutputStage=Nothing>
+class InputPatchElement
+{
+  public:
+  struct Workspace
+  {
+    constexpr static bool InputPatchIsGeneric = IsGeneric;
+    const T **input_rows;
+    T *input_padding;
+    T *input_patch;
+  };
+
+  static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    return sizeof_input_rows(args) + sizeof_input_padding(args) + sizeof_input_patch(args);
+  }
+
+  template <class WorkspaceType>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    auto buffer_bytes = reinterpret_cast<char *>(buffer);
+
+    ws->input_rows = reinterpret_cast<const T **>(buffer_bytes);
+    buffer_bytes += sizeof_input_rows(args);
+
+    ws->input_padding = reinterpret_cast<T*>(buffer_bytes);
+    buffer_bytes += sizeof_input_padding(args);
+
+    ws->input_patch = reinterpret_cast<T*>(buffer_bytes);
+    buffer_bytes += sizeof_input_patch(args);
+
+    // Initialise the padding
+    memset(ws->input_padding,
+           get_input_buffer_fill_value(args.output_stage),
+           sizeof_input_padding(args));
+
+    return buffer_bytes;
+  }
+
+  protected:
+  static size_t sizeof_input_rows(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    if (IsGeneric)
+    {
+      return sizeof(T *) * args.strategy->get_output_rows() * args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+    }
+    else
+    {
+      return sizeof(T *) * args.strategy->get_input_rows();
+    }
+  }
+
+  static size_t sizeof_input_padding(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    // Round-up the number of columns to be a whole number of QUADS
+    auto input_cols = arm_gemm::roundup<size_t>(args.strategy->get_input_cols(), 16 / sizeof(T));
+    return sizeof(T) * input_cols;
+  }
+
+  static size_t sizeof_input_patch(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    if (IsGeneric)
+    {
+      // Round-up the number of columns to be a whole number of QUADS
+      auto output_cols = arm_gemm::roundup<size_t>(args.strategy->get_output_cols(), 16 / sizeof(T));
+      const auto kernel_points = args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+      return sizeof(T) * kernel_points * args.strategy->get_output_rows() * output_cols;
+    }
+    else
+    {
+      // Round-up the number of columns to be a whole number of QUADS
+      auto input_cols = arm_gemm::roundup<size_t>(args.strategy->get_input_cols(), 16 / sizeof(T));
+      return sizeof(T) * args.strategy->get_input_rows() * input_cols;
+    }
+  }
+};
+
+template <bool IsGeneric, typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct StrategyType
+{
+  using Type = DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, TAccum>;
+
+  template <typename WorkspaceType>
+  static void execute(
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const OutputStage &, const unsigned int,
+    const void *parameters, const void *
+  )
+  {
+    strat->get_kernel()(
+      ws->input_rows,
+      ws->outptr_array,
+      parameters, args.channel_multiplier,
+      ws->activation_min, ws->activation_max
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct StrategyType<true, TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+  using Type = GenericDepthfirstMultiplierStrategy<TInput, TWeight, TOutput, TAccum, OutputStage>;
+
+  template <typename WorkspaceType>
+  static void execute(
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const OutputStage &, const unsigned int start_output_channel,
+    const void *parameters, const void *bias
+  )
+  {
+    strat->get_kernel()(
+      ws->input_rows, ws->outptr_array,
+      reinterpret_cast<const TWeight *>(parameters),
+      bias == nullptr ? nullptr : reinterpret_cast<const TAccum *>(bias) + start_output_channel,
+      strat->get_kernel_rows() * strat->get_kernel_cols(),
+      args.channel_multiplier,
+      ws->activation_min, ws->activation_max
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+struct StrategyType<false, TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+  using Type = DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t>;
+
+  template <typename WorkspaceType>
+  static void execute(
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const arm_gemm::Requantize32 &qp, const unsigned int,
+    const void *parameters, const void *
+  )
+  {
+    strat->get_kernel()(
+      ws->input_rows,
+      ws->outptr_array,
+      parameters, args.channel_multiplier,
+      qp
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+struct StrategyType<true, TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>  // Specialisation: `true` first argument, int32_t accumulator, Requantize32 output stage
+{
+  using Type = GenericDepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
+
+  template <typename WorkspaceType>
+  static void execute(  // Invoke the strategy's kernel once using the pointers held in the workspace
+    const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+    const arm_gemm::Requantize32 &qp, const unsigned int start_output_channel,
+    const void *parameters, const void *  // The trailing unnamed pointer is ignored; the bias is read from qp.bias
+  )
+  {
+    auto get_ptr = [start_output_channel] (const int32_t *ptr) -> const int32_t *  // Offset a pointer by start_output_channel, leaving null pointers null
+    {
+      return ptr == nullptr ? nullptr : ptr + start_output_channel;
+    };
+
+    strat->get_kernel()(
+      ws->input_rows, ws->outptr_array,
+      reinterpret_cast<const TWeight *>(parameters),
+      get_ptr(qp.bias),
+      strat->get_kernel_rows() * strat->get_kernel_cols(),  // Kernel rows * kernel columns
+      args.channel_multiplier,
+      get_ptr(qp.per_channel_left_shifts),  // The per-channel arrays receive the same offset as the bias
+      get_ptr(qp.per_channel_muls),
+      get_ptr(qp.per_channel_right_shifts),
+      qp
+    );
+  }
+};
+
+template <bool IsGeneric> struct PrepareInputSample;  // Declared only; the <false> and <true> specialisations provide execute()
+
+template <> struct PrepareInputSample<false>  // Delegates to fill_nchw_patch_array
+{
+  template <typename WorkspaceType, typename StrategyType, typename T>
+  static void execute(
+    const DepthwiseArgs &, WorkspaceType *ws, const StrategyType *strat,  // The DepthwiseArgs argument is unused here
+    T *base_ptr, size_t ld_row, size_t ld_col,
+    const unsigned int input_pad_top, const unsigned int valid_rows,
+    const unsigned int input_pad_left, const unsigned int valid_cols
+  )
+  {
+    fill_nchw_patch_array(
+      ws->input_rows, ws->input_patch, strat->get_input_rows(), strat->get_input_cols(),  // Input rows/columns are taken from the strategy
+      base_ptr, ld_row, ld_col,
+      ws->input_padding,
+      input_pad_top, valid_rows,
+      input_pad_left, valid_cols
+    );
+  }
+};
+
+template <> struct PrepareInputSample<true>  // Delegates to fill_patch_array_generic_kernel
+{
+  template <typename WorkspaceType, typename StrategyType, typename T>
+  static void execute(
+    const DepthwiseArgs &args, WorkspaceType *ws, const StrategyType *strat,
+    T *base_ptr, size_t ld_row, size_t ld_col,
+    const unsigned int input_pad_top, const unsigned int valid_rows,
+    const unsigned int input_pad_left, const unsigned int valid_cols
+  )
+  {
+    fill_patch_array_generic_kernel(
+      ws->input_rows, ws->input_patch,
+      strat->get_output_rows(), strat->get_output_cols(),  // Output rows/columns are taken from the strategy
+      args.kernel_rows, args.kernel_cols,  // Kernel size and strides are taken from the DepthwiseArgs
+      args.stride_rows, args.stride_cols,
+      base_ptr, ld_row, ld_col,
+      ws->input_padding,
+      input_pad_top, valid_rows,
+      input_pad_left, valid_cols
+    );
+  }
+};
+
+}  // namespace depthfirst_multiplier
+
+template <typename TInput,
+          typename TWeight=TInput,
+          typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TInput>::Type,
+          bool is_generic=false,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirstMultiplier : public DepthfirstDriver<TInput, TWeight, TOutput>
+{
+  protected:
+  using StratType = typename depthfirst_multiplier::StrategyType<is_generic, TInput, TWeight, TOutput, TAccum, OutputStage>::Type;
+  using WorkspaceManager = Workspace<
+    OutputArrayElement<TOutput>,
+    depthfirst_multiplier::InputPatchElement<TInput, is_generic, OutputStage>,
+    ActivationsElement<TOutput, OutputStage>
+  >;
+  using WorkingSpace = typename WorkspaceManager::WorkspaceType;
+
+  OutputStage m_os;  // Copy of the output parameters
+  const void *m_bias = nullptr;  // Bias pointer recorded by pack_parameters (only the pointer is kept; should we need it)
+
+  public:
+  DepthwiseDepthfirstMultiplier(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os = {})
+  : DepthfirstDriver<TInput, TWeight, TOutput>(strat, args), m_os(os)
+  {
+  }
+
+  DepthwiseDepthfirstMultiplier(DepthwiseDepthfirstMultiplier &) = delete;
+  DepthwiseDepthfirstMultiplier &operator=(DepthwiseDepthfirstMultiplier &) = delete;
 
   size_t get_storage_size(void) const override
   {
-    // TODO What if we insert extra padding? Biases are a different size to the inputs, ...
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(strategy::vl_type);
-    const auto rounded_channels = this->m_args.input_channels * arm_gemm::roundup(this->m_args.channel_multiplier, vl);
-    return (1 + this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight);
+    return reinterpret_cast<const StratType *>(this->m_strat.get())
+      ->get_storage_size(this->m_args);
   }
 
-  void pack_parameters(void *_buffer, const void *_biases, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override
+  void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
   {
-    // TODO What if the kernel needs a different packing function?
-
-    // Cast the pointers
-    float *buffer = static_cast<float *>(_buffer);
-    const float *biases = static_cast<const float *>(_biases);
-    const float *const weights = static_cast<const float *>(_weights);
-
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(strategy::vl_type);
-    ld_weight_col = (ld_weight_col == 0) ? this->m_args.channel_multiplier * this->m_args.input_channels : ld_weight_col;
-    ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row;
-
-    for (unsigned int in_c = 0; in_c < this->m_args.input_channels; in_c++)
-    {
-      for (unsigned int n = 0; n < this->m_args.channel_multiplier; n += vl)
-      {
-        const unsigned int out_c = in_c * this->m_args.channel_multiplier + n;
-        const unsigned int todo = std::min(vl, this->m_args.channel_multiplier - n);
-
-        // Copy across the correct amount of bias (or 0)
-        for (unsigned int i = 0; i < todo; i++)
-        {
-          buffer[i] = (biases == nullptr) ? 0 : biases[out_c + i];
-        }
-        buffer += vl;
-
-        // Copy each of the weights in turn
-        auto weights_row = weights + out_c;
-        for (unsigned int i = 0; i < this->m_args.kernel_rows; i++)
-        {
-          auto weights_col = weights_row;
-
-          for (unsigned int j = 0; j < this->m_args.kernel_cols; j++)
-          {
-            for (unsigned int m = 0; m < todo; m++)
-            {
-              buffer[m] = weights_col[m];
-            }
-            buffer += vl;
-
-            weights_col += ld_weight_col;
-          }
-
-          weights_row += ld_weight_row;
-        }
-      }
-    }
+    reinterpret_cast<const StratType *>(this->m_strat.get())
+      ->pack_parameters(this->m_args, buffer, biases, m_os, weights, ld_weight_col, ld_weight_row);
+    m_bias = biases;
+    depthwise_depthfirst::stash_bias(m_os, biases);
   }
 
-  size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override
+  size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
   {
-    const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier;
-    return n_threads * sizeof_output_buffer(n_output_channels);
+    DepthwiseArgs args(this->m_args);
+    args.input_channels = n_input_channels;
+    return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
   }
-  
-  using DepthwiseCommon<typename strategy::input_type, typename strategy::weight_type, typename strategy::return_type>::execute;
-  void execute(
-    const unsigned int batches,
-    const unsigned int input_height,
-    const unsigned int input_width,
-    const unsigned int input_channels,
-    const PaddingValues &padding,
-    const void *const _input,
-    const size_t ld_input_col,
-    const size_t ld_input_row,
-    const size_t ld_input_batch,
-    const void *const parameters,
-    const unsigned int output_height,
-    const unsigned int output_width,
-    void *const _output,
-    const size_t ld_output_col,
-    const size_t ld_output_row,
-    const size_t ld_output_batch,
-    void *const _working_space,
-    const unsigned int thread_id,
-    const unsigned int n_threads
+
+  void initialise_working_space(void *buffer, unsigned int n_input_channels) const override  // Initialise the workspace in `buffer` for n_input_channels input channels
+  {
+    DepthwiseArgs args(this->m_args);  // Work on a copy so that m_args is left unchanged
+    args.input_channels = n_input_channels;
+    return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
+  }
+
+  void compute_tile_padded(
+    unsigned int output_i, unsigned int output_j,
+    unsigned int output_channel_start, unsigned int output_channel_end,
+    const TensorSpec<const TInput *> &input,
+    const TensorSpec<TOutput *> &output,
+    const void *parameters,
+    void *working_space_raw
   ) const override
   {
-    strategy strat(this->m_args.cpu_info);
-#ifdef CYCLE_PROFILING
-    arm_gemm::profiler prof;
-#endif
+    // Get the working space
+    auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
 
-    // Compute activation values
-    TAccum activation_min = std::numeric_limits<TAccum>::has_infinity ? -std::numeric_limits<TAccum>::infinity() : std::numeric_limits<TAccum>::min();
-    TAccum activation_max = std::numeric_limits<TAccum>::has_infinity ? std::numeric_limits<TAccum>::infinity() : std::numeric_limits<TAccum>::max();
+    const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+    const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
+    const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
 
-    switch (this->m_args.activation.type)
+    const int ij = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+    const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);
+    const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);
+
+    // Compute the output pointer array. We'll update this array after every
+    // invocation of the kernel.
+    fill_pointer_array(
+      ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+      output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+      output.ld_row, output.ld_col,
+      ws->output_buffer,
+      0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+      0, this->m_args.output_cols - output_j  // Left padding, # valid columns
+    );
+
+    // Compute the parameter stride: the storage size needed for a single input channel
+    DepthwiseArgs single_iter(this->m_args);
+    single_iter.input_channels = 1;
+    const size_t parameter_stride = reinterpret_cast<const StratType *>(this->m_strat.get())
+      ->get_storage_size(single_iter);
+
+    for (; output_channel_start < output_channel_end;
+         output_channel_start += this->m_args.channel_multiplier)
     {
-      case arm_gemm::Activation::Type::BoundedReLU:
-        activation_max = static_cast<TAccum>(this->m_args.activation.param1);
-        // Fall through
-      case arm_gemm::Activation::Type::ReLU:
-        activation_min = static_cast<TAccum>(0);
-        break;
-      default:
-        break;
-    }
+      // Determine the input channel which corresponds to this block of output channels
+      const auto input_channel = output_channel_start / this->m_args.channel_multiplier;
 
-    // Determine what portion of the work to do.
-    const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
-    const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
-    const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
+      // Construct the input patch
+      depthfirst_multiplier::PrepareInputSample<is_generic>::execute(
+        this->m_args, ws, this->m_strat.get(),
+        input.base + input_channel + input_i*input.ld_row + input_j*input.ld_col, input.ld_row, input.ld_col,
+        input_pad_top, this->m_args.input_rows - input_i,
+        input_pad_left, this->m_args.input_cols - input_j
+      );
 
-    // Need a stride over blocks of parameters
-    const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(strategy::vl_type);
-    const unsigned int param_stride =
-      arm_gemm::roundup(this->m_args.channel_multiplier, vl) *
-      (sizeof(TAccum) + sizeof(TWeight) * strategy::kernel_rows * strategy::kernel_cols);
+      // Execute the kernel
+      depthfirst_multiplier::StrategyType<is_generic, TInput, TWeight, TOutput, TAccum, OutputStage>::execute(
+        this->m_args, ws, reinterpret_cast<const StratType *>(this->m_strat.get()), m_os, output_channel_start,
+        parameters, m_bias
+      );
 
-    // Cast input and output pointers into the right types
-    const TInput *const inptr = static_cast<const TInput *>(_input);
-    TOutput *const outptr = static_cast<TOutput *>(_output);
-
-    // To simplify the kernel, we process padded or non-NCHW-ordered input into
-    // a form which can be consumed by the kernel. This data is stored here and
-    // passed into the kernel as an array of N pointers (one per row of the
-    // input).
-    TInput rearranged_input[strategy::input_rows][strategy::input_col_quads*4];
-    const TInput *inptrs[strategy::input_rows];
-
-    // Create an array for the output pointers
-    TOutput * _outptr_array[strategy::output_rows * strategy::output_cols];
-    TOutput **const outptr_array = _outptr_array;
-
-    // Allocate portions of the working space
-    uint8_t *const working_space = static_cast<uint8_t *>(_working_space) + get_working_size(thread_id, input_channels);
-    TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
-
-    // For each output tile, construct the requisite set of pointers and call
-    // into the kernel.
-    for (unsigned int batch = 0; batch < batches; batch++)
-    {
-      // Get batch pointers
-      const auto inptr_batch = inptr + batch * ld_input_batch;
-      const auto outptr_batch = outptr + batch * ld_output_batch;
-
-      for (int start_out_i = start_out_height;
-           start_out_i < end_out_height;
-           start_out_i += static_cast<int>(strategy::output_rows))
+      // Update the output pointers
+      for (unsigned int n = 0; n < this->m_strat->get_output_rows() * this->m_strat->get_output_cols(); n++)
       {
-        const int end_out_i = start_out_i + strategy::output_rows;
-        const int start_in_i = start_out_i * strategy::stride_rows - padding.top;
-        const int end_in_i = start_in_i + strategy::input_rows;
-
-        // Compute top/bottom padding
-        const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
-        const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
-        const unsigned int valid_output_rows = std::min(
-          end_out_i - start_out_i,
-          static_cast<int>(output_height) - start_out_i
-        );
-
-        for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
-        {
-          const int start_in_j = start_out_j * strategy::stride_cols - this->m_args.padding.left;
-          const int pad_left = -std::min(0, start_in_j);
-
-          const int end_out_j = start_out_j + strategy::output_cols;
-          const int end_in_j = start_in_j + strategy::input_cols;
-
-          const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
-          const unsigned int valid_output_cols = std::min(
-            end_out_j - start_out_j,
-            static_cast<int>(output_width) - start_out_j
-          );
-
-          // Construct the output pointer array.
-          TOutput **outptr_pos = outptr_array;
-          for (auto i = 0u; i < valid_output_rows; i++)
-          {
-            unsigned int j = 0u;
-            TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
-            for (; j < valid_output_cols; j++)
-            {
-              *(outptr_pos++) = colptr;
-               colptr += ld_output_col;
-            }
-            for (; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-          for (auto i = valid_output_rows; i < strategy::output_rows; i++)
-          {
-            for (auto j = 0u; j < strategy::output_cols; j++)
-            {
-              *(outptr_pos++) = output_buffer;
-            }
-          }
-
-          start_out_j += strategy::output_cols;
-
-          const uint8_t *params = static_cast<const uint8_t *>(parameters);
-
-          // Loop over the input channels
-          for (unsigned int in_c = 0; in_c < input_channels; in_c++)
-          {
-            // Construct the input array - first fill with padding values and
-            // then fill in correct values.
-            for (unsigned int i = 0; i < strategy::input_rows; i++)
-            {
-              for (unsigned int j = 0; j < 4 * strategy::input_col_quads; j++)
-              {
-                rearranged_input[i][j] = static_cast<TInput>(0);
-              }
-              inptrs[i] = rearranged_input[i];
-            }
-
-            auto inptr_row = inptr_batch + in_c +
-                             (start_in_i + pad_top) * ld_input_row +
-                             (start_in_j + pad_left) * ld_input_col;
-            if (ld_input_col == 1 && !pad_left &&
-                start_in_j + 4 * strategy::input_col_quads < input_width)
-            {
-              // The input tensor is already in NCHW format, and we're reading
-              // an unpadded section of it - allow the kernel to read it
-              // directly.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                inptrs[i] = inptr_row;
-                inptr_row += ld_input_row;
-              }
-            }
-            else
-            {
-              // Either the input tensor isn't in NCHW format, or we're reading
-              // a padded section. Copy the relevant portion of the input here
-              // and allow the kernel to read this.
-              for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
-              {
-                auto inptr_col = inptr_row;
-                for (unsigned int j = pad_left; j < strategy::input_cols - pad_right; j++)
-                {
-                  rearranged_input[i][j] = *inptr_col;
-                  inptr_col += ld_input_col;
-                }
-                inptr_row += ld_input_row;
-              }
-            }
-
-            {
-#ifdef CYCLE_PROFILING
-              auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(strategy::output_rows * strategy::output_cols * this->m_args.channel_multiplier * strategy::kernel_rows * strategy::kernel_cols));
-#endif
-              strat.kernel(
-                inptrs, outptr_array, params,
-                this->m_args.channel_multiplier,
-                activation_min, activation_max
-              );
-            }
-
-            // Progress the output pointers
-            TOutput **outptr_pos = outptr_array;
-            for (auto i = 0u; i < strategy::output_rows * strategy::output_cols; i++)
-            {
-              outptr_pos[i] += this->m_args.channel_multiplier;
-            }
-
-            // Progress the pointer into the parameters
-            params += param_stride;
-          }
-        }
+        ws->outptr_array[n] += this->m_args.channel_multiplier;
       }
+
+      // Progress the parameters
+      parameters = reinterpret_cast<const char *>(parameters) + parameter_stride;
     }
   }
 };
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp
index 934272a..6b100d9 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp16.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,6 @@
 #include "depthwise_depthfirst.hpp"
 #include "depthwise_depthfirst_generic.hpp"
 #include "depthwise_depthfirst_multiplier.hpp"
-#include "depthwise_depthfirst_generic_multiplier.hpp"
 
 #include "depthwise_implementation_constraints.hpp"
 
@@ -43,6 +42,7 @@
 #include "kernels/sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp"
 #include "kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
 #endif  // defined(ARM_COMPUTE_ENABLE_SVE)
+#if defined(ENABLE_FP16_KERNELS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 #include "kernels/a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp"
 #include "kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp"
 #include "kernels/a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
@@ -50,6 +50,7 @@
 #include "kernels/a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst.hpp"
 #include "kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp"
+#endif  // defined(ENABLE_FP16_KERNELS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 #endif  // defined(__aarch64__)
 
 namespace arm_conv {
@@ -70,15 +71,11 @@
   }
 
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+  unsigned int not_preferred(const DepthwiseArgs &, const Nothing &) __attribute__ ((unused));
   unsigned int not_preferred(const DepthwiseArgs &, const Nothing &)
   {
     return std::numeric_limits<unsigned int>::max();
   }
-
-  unsigned int not_preferred_if_no_multiplier(const DepthwiseArgs &args, const Nothing &)
-  {
-    return args.channel_multiplier > 1 ? 0 : std::numeric_limits<unsigned int>::max();
-  }
 #endif  // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 }
 
@@ -94,7 +91,7 @@
     cycle_estimate<sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -106,7 +103,7 @@
     cycle_estimate<sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -118,7 +115,7 @@
     cycle_estimate<sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -130,7 +127,7 @@
     cycle_estimate<sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -142,11 +139,11 @@
     cycle_estimate<sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
 #endif  // defined(ARM_COMPUTE_ENABLE_SVE)
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(ENABLE_FP16_KERNELS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst",
@@ -156,7 +153,7 @@
     cycle_estimate<a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -168,7 +165,7 @@
     cycle_estimate<a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -180,7 +177,7 @@
     cycle_estimate<a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -192,7 +189,7 @@
     cycle_estimate<a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -204,7 +201,7 @@
     cycle_estimate<a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
       auto strat = new a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<__fp16, __fp16, __fp16, __fp16>(strat, args);
+      return new DepthwiseDepthfirst<__fp16>(strat, args);
     },
   },
   {
@@ -213,19 +210,23 @@
     constraint(has_no_channel_multiplier, cpu_has_fp16),
     not_preferred,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
-      return new DepthwiseDepthfirstGeneric<a64_fp16_nhwc_generic_output9_mla_depthfirst, 3, 3>(args);
+      auto kern = new a64_fp16_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstStrategy<__fp16>(kern, 3, 3, args);
+      return new DepthwiseDepthfirstGeneric<__fp16>(strat, args);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp16_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
-    constraint(cpu_has_fp16),
-    not_preferred_if_no_multiplier,
+    constraint(cpu_has_fp16, has_channel_multiplier),
+    nullptr,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<__fp16, __fp16, __fp16> * {
-      return new DepthwiseDepthfirstGenericWithMultiplier<a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst>(args);
+      auto kern = new a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstMultiplierStrategy<__fp16>(kern, args);
+      return new DepthwiseDepthfirstMultiplier<__fp16, __fp16, __fp16, __fp16, true>(strat, args);
     },
   },
-#endif  // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif  // defined(ENABLE_FP16_KERNELS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 #endif  // defined(__aarch64__)
   { DepthwiseMethod::DEFAULT, "", nullptr, nullptr, nullptr },  // End of list
 };
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp
index 5107dda..643cf1d 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_fp32.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 #include "depthwise_depthfirst.hpp"
 #include "depthwise_depthfirst_generic.hpp"
 #include "depthwise_depthfirst_multiplier.hpp"
-#include "depthwise_depthfirst_generic_multiplier.hpp"
+#include "depthwise_planar.hpp"
 
 #include "depthwise_implementation_constraints.hpp"
 
@@ -78,9 +78,10 @@
     return std::numeric_limits<unsigned int>::max();
   }
 
-  unsigned int not_preferred_if_no_multiplier(const DepthwiseArgs &args, const Nothing &)
+  bool fast_mode_enabled(const DepthwiseArgs &args, const void *) __attribute__ ((unused));
+  bool fast_mode_enabled(const DepthwiseArgs &args, const void *)
   {
-    return args.channel_multiplier > 1 ? 0 : std::numeric_limits<unsigned int>::max();
+    return args.fast_mode;
   }
 #endif // defined(__aarch64__)
 }
@@ -97,7 +98,7 @@
     cycle_estimate<sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -109,7 +110,7 @@
     cycle_estimate<sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -121,7 +122,7 @@
     cycle_estimate<sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -133,7 +134,7 @@
     cycle_estimate<sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -145,7 +146,7 @@
     cycle_estimate<sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -154,36 +155,42 @@
     constraint(has_no_channel_multiplier, cpu_has_sve),
     not_preferred,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstGeneric<sve_fp32_nhwc_generic_output9_mla_depthfirst, 3, 3>(args);
+      auto kern = new sve_fp32_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstStrategy<float>(kern, 3, 3, args);
+      return new DepthwiseDepthfirstGeneric<float>(strat, args);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst",
     constraint(is_supported<sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst>,
-               cpu_has_sve),
-    not_preferred_if_no_multiplier,
+               cpu_has_sve, has_channel_multiplier),
+    nullptr,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstWithMultiplier<sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst>(args);
+      auto strat = new sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<float>(strat, args);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst",
     constraint(is_supported<sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst>,
-               cpu_has_sve),
-    not_preferred_if_no_multiplier,
+               cpu_has_sve, has_channel_multiplier),
+    nullptr,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstWithMultiplier<sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst>(args);
+      auto strat = new sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<float>(strat, args);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_fp32_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
-    constraint(cpu_has_sve),
-    not_preferred_if_no_multiplier,
+    constraint(cpu_has_sve, has_channel_multiplier),
+    nullptr,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstGenericWithMultiplier<sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst>(args);
+      auto kern = new sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstMultiplierStrategy<float>(kern, args);
+      return new DepthwiseDepthfirstMultiplier<float, float, float, float, true>(strat, args);
     },
   },
 #endif  // defined(ARM_COMPUTE_ENABLE_SVE)
@@ -195,7 +202,7 @@
     cycle_estimate<a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -206,7 +213,7 @@
     cycle_estimate<a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -217,7 +224,7 @@
     cycle_estimate<a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -228,7 +235,7 @@
     cycle_estimate<a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -239,7 +246,7 @@
     cycle_estimate<a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst>,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
       auto strat = new a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
-      return new DepthwiseDepthfirst<float, float, float, float>(strat, args);
+      return new DepthwiseDepthfirst<float>(strat, args);
     },
   },
   {
@@ -248,34 +255,42 @@
     constraint(has_no_channel_multiplier),
     not_preferred,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstGeneric<a64_fp32_nhwc_generic_output9_mla_depthfirst, 3, 3>(args);
+      auto kern = new a64_fp32_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstStrategy<float>(kern, 3, 3, args);
+      return new DepthwiseDepthfirstGeneric<float>(strat, args);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst",
-    constraint(is_supported<a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst>),
-    not_preferred_if_no_multiplier,
+    constraint(is_supported<a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst>,
+               has_channel_multiplier),
+    nullptr,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstWithMultiplier<a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst>(args);
+      auto strat = new a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<float>(strat, args);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst",
-    constraint(is_supported<a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst>),
-    not_preferred_if_no_multiplier,
+    constraint(is_supported<a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst>,
+               has_channel_multiplier),
+    nullptr,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstWithMultiplier<a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst>(args);
+      auto strat = new a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<float>(strat, args);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_fp32_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
+    constraint(has_channel_multiplier),
     nullptr,
-    not_preferred_if_no_multiplier,
     [] (const DepthwiseArgs &args, const Nothing &) -> DepthwiseCommon<float, float, float> * {
-      return new DepthwiseDepthfirstGenericWithMultiplier<a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst>(args);
+      auto kern = new a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstMultiplierStrategy<float>(kern, args);
+      return new DepthwiseDepthfirstMultiplier<float, float, float, float, true>(strat, args);
     },
   },
 #endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp
index ea41529..0665fa3 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,7 +24,7 @@
 
 #pragma once
 
-#include "src/core/NEON/kernels/assembly/depthwise.hpp"
+#include "depthwise.hpp"
 
 #include <cstddef>
 #include <functional>
@@ -136,14 +136,7 @@
 {
   const DepthwiseImplementation<TInput, TWeight, TOutput, OutputStage> *impl = nullptr;
   const bool success = find_implementation<TInput, TWeight, TOutput, OutputStage>(args, os, impl);
-
-  if(success)
-  {
-        auto i =  impl->get_instance(args, os);
-        i->set_name(impl->name);
-        return UniqueDepthwiseCommon<TInput, TWeight, TOutput>(i);
-  }
-  return nullptr;
+  return UniqueDepthwiseCommon<TInput, TWeight, TOutput>(success ? impl->get_instance(args, os) : nullptr);
 }
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation_constraints.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation_constraints.hpp
index 4198727..78b6aec 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation_constraints.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation_constraints.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -109,6 +109,12 @@
   return args.channel_multiplier == 1;
 }
 
+bool has_channel_multiplier(const DepthwiseArgs &args, const void *) __attribute__ ((unused));  // Forward declaration carrying the `unused` attribute.
+bool has_channel_multiplier(const DepthwiseArgs &args, const void *)  // Constraint: holds only when the channel multiplier is greater than one; the second argument is ignored.
+{
+  return args.channel_multiplier > 1;
+}
+
 bool qp_has_no_left_shift(const DepthwiseArgs &args, const void *_qp) __attribute__ ((unused));
 bool qp_has_no_left_shift(const DepthwiseArgs &, const void *_qp)
 {
@@ -118,6 +124,21 @@
     (qp->per_layer_left_shift == 0);
 }
 
+bool qp_zero_a_offset(const DepthwiseArgs &args, const void *_qp) __attribute__ ((unused));  // Forward declaration carrying the `unused` attribute.
+bool qp_zero_a_offset(const DepthwiseArgs &, const void *_qp)  // Constraint: holds when the quantisation parameters have a zero `a_offset`; the DepthwiseArgs are ignored.
+{
+  const auto qp = static_cast<const arm_gemm::Requantize32 *>(_qp);  // `_qp` is treated as a Requantize32 and dereferenced without a null check.
+  return qp->a_offset == 0;
+}
+
+template <typename T> bool qp_skip_clamp(const DepthwiseArgs &args, const void *_qp) __attribute__ ((unused));  // Forward declaration carrying the `unused` attribute.
+template <typename T> bool qp_skip_clamp(const DepthwiseArgs &, const void *_qp)  // Constraint: holds when the clamp bounds equal numeric_limits<T>::min() and ::max(); the DepthwiseArgs are ignored.
+{
+  const auto qp = static_cast<const arm_gemm::Requantize32 *>(_qp);  // `_qp` is treated as a Requantize32 and dereferenced without a null check.
+  return (qp->minval == std::numeric_limits<T>::min() &&
+          qp->maxval == std::numeric_limits<T>::max());
+}
+
 }  // namespace
 }  // namespace depthwise
 }  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_planar.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_planar.hpp
new file mode 100644
index 0000000..ff5098d
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_planar.hpp
@@ -0,0 +1,409 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "depthfirst_driver.hpp"
+#include "interleaves/generic.hpp"
+
+namespace arm_conv {
+namespace depthwise {
+
+template <typename OutputStage>
+class IPlanarStrategy  // Abstract interface through which DepthwisePlanar talks to a planar strategy.
+{
+  public:
+  virtual ~IPlanarStrategy() = default;
+  virtual unsigned int get_output_rows(void) const = 0;  // Output rows handled per kernel call; also sizes the per-row workspace arrays.
+  virtual arm_gemm::VLType get_vl_type(void) const = 0;  // Vector-length type handed to get_vector_length<>().
+
+  virtual size_t get_storage_size(const DepthwiseArgs &) const = 0;  // Size of the packed-parameter buffer (presumably in bytes -- TODO confirm).
+  virtual void pack_parameters(  // Packs parameters into `buffer` from the given biases, output stage and weights.
+    const DepthwiseArgs &args, void *buffer,
+    const void *biases, const OutputStage &,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const = 0;
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum,
+          typename OutputStage>
+struct PlanarKernelType;  // Primary template is declared only; each supported output stage provides a specialisation.
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+struct PlanarKernelType<TInput, TWeight, TOutput, TAccum, Nothing>  // Kernel signature and call adaptor for the `Nothing` output stage.
+{
+  using Type = std::function<void(  // Callable type a strategy's get_kernel() must return for this output stage.
+    const TInput *, size_t ld_in_row, size_t ld_in_col, size_t ld_in_vl,
+    unsigned int pad_top, unsigned int valid_input_rows,
+    unsigned int pad_left, unsigned int valid_input_cols,
+    const TWeight *, const TAccum *,
+    TOutput **, const size_t *, const size_t *, unsigned int output_cols,
+    unsigned int start_channels, unsigned int valid_channels,
+    TAccum act_min, TAccum act_max
+  )>;
+
+  template <typename WorkspaceType>
+  static inline void execute(  // Forwards every argument to `fn`, substituting activation bounds for the output stage.
+    const Type fn,
+    const TInput *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_in_vl,
+    unsigned int pad_top, unsigned int valid_input_rows,
+    unsigned int pad_left, unsigned int valid_input_cols,
+    const TWeight *weights, const TAccum *bias,
+    TOutput **outptrs, const size_t *outlds, const size_t *outvllds, unsigned int output_cols,
+    unsigned int start_channel, unsigned int valid_channels,
+    const Nothing &, const WorkspaceType *ws
+  )
+  {
+    fn(
+      inptr, ld_in_row, ld_in_col, ld_in_vl,
+      pad_top, valid_input_rows,
+      pad_left, valid_input_cols,
+      weights, bias,
+      outptrs, outlds, outvllds, output_cols,
+      start_channel, valid_channels,
+      ws->activation_min, ws->activation_max  // The `Nothing` stage carries no data; the bounds are read from the workspace.
+    );
+  }
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+struct PlanarKernelType<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>  // Kernel signature and call adaptor for the Requantize32 output stage.
+{
+  using Type = std::function<void(  // Callable type a strategy's get_kernel() must return; note it takes no bias pointer.
+    const TInput *, size_t ld_in_row, size_t ld_in_col, size_t ld_in_vl,
+    unsigned int pad_top, unsigned int valid_input_rows,
+    unsigned int pad_left, unsigned int valid_input_cols,
+    const TWeight *,
+    TOutput **, const size_t *, const size_t *, unsigned int output_cols,
+    unsigned int start_channel, unsigned int valid_channels,
+    const arm_gemm::Requantize32 &
+  )>;
+
+  template <typename WorkspaceType>
+  static inline void execute(  // Forwards to `fn`; the bias pointer and the workspace are accepted but unused.
+    const Type fn,
+    const TInput *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_in_vl,
+    unsigned int pad_top, unsigned int valid_input_rows,
+    unsigned int pad_left, unsigned int valid_input_cols,
+    const TWeight *weights, const int32_t *,
+    TOutput **outptrs, const size_t *outlds, const size_t *outldvls, unsigned int output_cols,
+    unsigned int first_channel, unsigned int valid_channels,
+    const arm_gemm::Requantize32 &qp, const WorkspaceType *
+  )
+  {
+    fn(
+      inptr, ld_in_row, ld_in_col, ld_in_vl,
+      pad_top, valid_input_rows,
+      pad_left, valid_input_cols,
+      weights,
+      outptrs, outlds, outldvls, output_cols,
+      first_channel, valid_channels,
+      qp  // The requantisation parameters are handed to the kernel unchanged.
+    );
+  }
+};
+
+
+template <typename TInput, typename TWeight=TInput, typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TOutput>::Type,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class PlanarStrategy : public IPlanarStrategy<OutputStage>  // Typed base for planar strategies: generic parameter packing plus a pure-virtual kernel accessor.
+{
+  unsigned int m_kernel_rows, m_kernel_cols;
+  unsigned int m_stride_rows, m_stride_cols;  // Stored by the constructor; not read anywhere inside this class.
+  unsigned int m_output_rows;
+  arm_gemm::VLType m_vl_type;
+
+  protected:
+  virtual bool get_kernel_packing_point(const unsigned int index, unsigned int &x, unsigned int &y) const
+  {
+    // Map a linear packing index onto a kernel point; return false to signal
+    // that this index (and every greater index) is out of range.
+    if (m_kernel_rows * m_kernel_cols <= index)
+      return false;
+
+    y = index % m_kernel_cols;  // `y` receives the position within a kernel row...
+    x = index / m_kernel_cols;  // ...and `x` the kernel row itself, i.e. a row-major walk.
+    return true;
+  }
+
+  virtual interleaves::PackingArguments get_kernel_packing_arguments(void) const  // Describes the packed layout to the generic interleave helpers.
+  {
+    return interleaves::PackingArguments(
+      m_kernel_rows, m_kernel_cols, sizeof(TWeight),
+      false, sizeof(TAccum),  // Don't pack the bias
+      m_vl_type, sizeof(TAccum), 1,  // Accumulator depth of 1 TODO
+      [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool  // Captures `this`, so the result must not outlive the strategy.
+      { return this->get_kernel_packing_point(idx, x, y); }
+    );
+  }
+
+  public:
+  PlanarStrategy(  // Records the kernel shape, strides, output-row count and vector-length type.
+    unsigned int kernel_rows, unsigned int kernel_cols,
+    unsigned int stride_rows, unsigned int stride_cols,
+    unsigned int output_rows,
+    arm_gemm::VLType vl_type
+  ) : m_kernel_rows(kernel_rows), m_kernel_cols(kernel_cols),
+      m_stride_rows(stride_rows), m_stride_cols(stride_cols),
+      m_output_rows(output_rows), m_vl_type(vl_type)
+  {
+  }
+
+  unsigned int get_output_rows(void) const override { return m_output_rows; }
+  arm_gemm::VLType get_vl_type(void) const override { return m_vl_type; }
+
+  size_t get_storage_size(const DepthwiseArgs &args) const override  // Delegates to the generic interleave sizing helper.
+  {
+    return interleaves::get_storage_size_generic(this->get_kernel_packing_arguments(), args);
+  }
+
+  void pack_parameters(  // Delegates to the generic interleave packer; the output stage argument is ignored.
+    const DepthwiseArgs &args, void *buffer,
+    const void *biases, const OutputStage &,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleaves::pack_parameters_generic(
+      this->get_kernel_packing_arguments(), args,
+      buffer, biases, weights, ld_weight_col, ld_weight_row
+    );
+  }
+
+  using KernelType = typename PlanarKernelType<TInput, TWeight, TOutput, TAccum, OutputStage>::Type;
+  virtual KernelType get_kernel(void) const = 0;  // Concrete strategies return the callable that performs the convolution.
+};
+
+
+namespace {
+
+template <typename T>
+struct OutputRowPtrsElement  // Workspace element holding per-output-row pointers and strides plus a padding buffer.
+{
+  struct Workspace
+  {
+    T **output_row_ptrs;  // One destination pointer per output row
+    size_t *output_ld_cols;  // One column stride per output row
+    size_t *output_ld_vls;  // Stride between vectors of channels
+    T *output_padding_buffer;  // Destination used for rows that have no valid output
+  };
+
+  template <typename OutputStage>
+  static size_t get_element_size(const WorkspaceArgs<IPlanarStrategy<OutputStage>, OutputStage> &args)
+  {
+    // We need one pointer and two strides (column and vector) for each row of
+    // output, plus a vector-sized blob of memory into which padded stores can go.
+    return args.strategy->get_output_rows() * (sizeof(T *) + 2*sizeof(size_t)) +
+           get_vector_length<char>(args.strategy->get_vl_type());
+  }
+
+  template <typename WorkspaceType, typename OutputStage>
+  static void *initialise(WorkspaceType *ws, void *buffer,
+                          const WorkspaceArgs<IPlanarStrategy<OutputStage>, OutputStage> &args)
+  {
+    const auto n_rows = args.strategy->get_output_rows();
+    ws->output_row_ptrs = reinterpret_cast<T **>(buffer);  // The arrays are carved out of `buffer` back to back, in this order.
+    ws->output_ld_cols = reinterpret_cast<size_t *>(ws->output_row_ptrs + n_rows);
+    ws->output_ld_vls = ws->output_ld_cols + n_rows;
+    ws->output_padding_buffer = reinterpret_cast<T *>(ws->output_ld_vls + n_rows);
+    return ws->output_padding_buffer + get_vector_length<T>(args.strategy->get_vl_type());  // First address past this element's storage.
+  }
+};
+
+}  // namespace {anonymous}
+
+
+template <typename TInput, typename TWeight=TInput, typename TOutput=TInput,
+          typename TAccum=typename DefaultTAccum<TOutput>::Type,
+          typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwisePlanar : public DepthwiseCommon<TInput, TWeight, TOutput>  // Planar depthwise driver: owns a strategy and stripes output rows across threads.
+{
+  using Parent = DepthwiseCommon<TInput, TWeight, TOutput>;
+  using StrategyType = IPlanarStrategy<OutputStage>;
+  using WorkspaceManager = Workspace<
+    OutputRowPtrsElement<TOutput>,
+    ActivationsElement<TAccum, OutputStage>
+  >;
+  using WorkspaceType = typename WorkspaceManager::WorkspaceType;
+
+  std::unique_ptr<StrategyType> m_strat;  // Owned; the constructor adopts the raw pointer it is given.
+  const TAccum *m_bias;  // Non-owning copy of the pointer passed to pack_parameters(); forwarded to the kernel.
+  OutputStage m_os;
+
+  public:
+  DepthwisePlanar(StrategyType *const strat, const DepthwiseArgs &args, const OutputStage &os = {})  // Takes ownership of `strat`.
+  : Parent(args), m_strat(strat), m_bias(nullptr), m_os(os)
+  {
+  }
+
+  size_t get_storage_size(void) const override  // Size of the packed-parameter buffer, as reported by the strategy.
+  {
+    return m_strat->get_storage_size(this->m_args);
+  }
+
+  void pack_parameters(  // Packs via the strategy, then remembers the bias pointer for use at execution time.
+    void *buffer, const void *biases,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) override
+  {
+    m_strat->pack_parameters(this->m_args, buffer, biases, {}, weights, ld_weight_col, ld_weight_row);  // NOTE(review): passes a default-constructed output stage rather than m_os -- confirm this is intended.
+    this->m_bias = reinterpret_cast<const TAccum *>(biases);
+    depthwise_depthfirst::stash_bias(this->m_os, biases);
+  }
+
+  size_t get_working_size(unsigned int n_threads, unsigned int) const override  // One workspace per thread; the second argument is unused.
+  {
+    return this->get_working_size_per_thread() * n_threads;
+  }
+
+  protected:
+  /* Compute the amount of working space required for a single thread. */
+  virtual size_t get_working_size_per_thread(void) const
+  {
+    return WorkspaceManager::get_sizeof_workspace(
+      WorkspaceArgs<IPlanarStrategy<OutputStage>, OutputStage>(m_strat.get(), this->m_args, m_os));
+  }
+
+  /* Initialise the working space for a thread. */
+  virtual void initialise_working_space(void *buffer) const
+  {
+    WorkspaceManager::initialise(
+      buffer,
+      WorkspaceArgs<IPlanarStrategy<OutputStage>, OutputStage>(m_strat.get(), this->m_args, m_os)
+    );
+  }
+
+  /* Execute the kernel for a given chunk of work. */
+  virtual void execute_kernel(
+    const TInput *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_in_vl,
+    unsigned int pad_top, unsigned int valid_input_rows,
+    unsigned int pad_left, unsigned int valid_input_cols,
+    const TWeight *weights, const TAccum *bias,
+    TOutput *outptr, size_t ld_out_row, size_t ld_out_col, size_t ld_out_vl,
+    unsigned int valid_output_rows, unsigned int valid_output_cols,
+    unsigned int first_channel, unsigned int valid_channels,
+    WorkspaceType *ws
+  ) const
+  {
+    // Initialise the output pointers
+    for (auto i = 0u; i < m_strat->get_output_rows(); i++)
+    {
+      // Point at the output tensor for all valid rows; otherwise point at the
+      // padding buffer.
+      ws->output_row_ptrs[i] = i < valid_output_rows ? outptr : ws->output_padding_buffer;
+      ws->output_ld_cols[i] = i < valid_output_rows ? ld_out_col : 0;  // Zero strides keep padded rows writing inside the padding buffer.
+      ws->output_ld_vls[i] = i < valid_output_rows ? ld_out_vl : 0;
+      outptr += ld_out_row;
+    }
+
+    // Execute the kernel
+    PlanarKernelType<TInput, TWeight, TOutput, TAccum, OutputStage>::template execute<WorkspaceType>(
+      static_cast<const PlanarStrategy<TInput, TWeight, TOutput, TAccum, OutputStage> *>(m_strat.get())->get_kernel(),  // Downcast from the interface; assumes the strategy derives from PlanarStrategy with these template arguments.
+      inptr, ld_in_row, ld_in_col, ld_in_vl,
+      pad_top, valid_input_rows, pad_left, valid_input_cols,
+      weights, bias,
+      ws->output_row_ptrs, ws->output_ld_cols, ws->output_ld_vls,
+      valid_output_cols, first_channel, valid_channels,
+      this->m_os, ws
+    );
+  }
+
+  void execute_internal(  // Per-thread entry point: for every batch, runs the kernel on this thread's stripes of output rows.
+    unsigned int batches,
+    unsigned int input_height,
+    unsigned int input_width,
+    unsigned int n_input_channels,
+    const PaddingValues &padding,
+    const void *input,
+    size_t ld_input_col,
+    size_t ld_input_row,
+    size_t ld_input_batch,
+    const void *parameters,
+    unsigned int output_height,
+    unsigned int output_width,
+    void *output,
+    size_t ld_output_col,
+    size_t ld_output_row,
+    size_t ld_output_batch,
+    void *working_space,
+    unsigned int thread_id,
+    unsigned int n_threads
+  ) const override
+  {
+    // Get and initialise the working space for this thread.
+    void *thread_working_space =
+      static_cast<uint8_t *>(working_space) + thread_id * this->get_working_size_per_thread();
+    this->initialise_working_space(thread_working_space);
+    auto ws = reinterpret_cast<WorkspaceType *>(thread_working_space);
+
+    const auto n_output_channels = n_input_channels * this->m_args.channel_multiplier;
+    const auto vl = get_vector_length<TAccum>(m_strat->get_vl_type());  // Passed to the kernel as both the input and output vector stride.
+
+    // Get typed pointers
+    auto input_batch = reinterpret_cast<const TInput *>(input);
+    auto output_batch = reinterpret_cast<TOutput *>(output);
+    auto weights = reinterpret_cast<const TWeight *>(parameters);
+
+    // Iterate over batches
+    for (; batches; batches--)
+    {
+      // NOTE: Other loop orderings are possible and it would be worth
+      // investigating them.
+
+      // Within a batch, stripe threads across rows.
+      for (auto start_output_i = thread_id * m_strat->get_output_rows();
+           start_output_i < output_height;
+           start_output_i += n_threads * m_strat->get_output_rows())
+      {
+        // Determine what padding (if any) is required on the top/bottom of
+        // this stripe of the convolution.
+        const int start_input_i = start_output_i * this->m_args.stride_rows - padding.top;  // Negative when the stripe starts inside the top padding.
+        const unsigned int input_pad_top = start_input_i < 0 ? -start_input_i : 0;
+        const unsigned int input_i = start_input_i < 0 ? 0 : start_input_i;
+        const unsigned int valid_input_rows = input_i > input_height ? 0 : input_height - input_i;
+        const unsigned int valid_output_rows = output_height - start_output_i;  // May exceed get_output_rows(); execute_kernel only compares against it.
+
+        auto inptr_row = input_batch + input_i*ld_input_row;
+        auto outptr_row = output_batch + start_output_i * ld_output_row;
+
+        // Execute the kernel
+        this->execute_kernel(
+          inptr_row, ld_input_row, ld_input_col, vl,
+          input_pad_top, valid_input_rows, padding.left, input_width,
+          weights, this->m_bias,
+          outptr_row, ld_output_row, ld_output_col, vl,
+          valid_output_rows, output_width,
+          0 /* first channel */, n_output_channels,
+          ws
+        );
+      }
+
+      // Update the input and output pointers to account for batch
+      input_batch += ld_input_batch;
+      output_batch += ld_output_batch;
+    }
+  }
+};
+
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_s8q.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_s8q.cpp
index 46a3118..4ff249a 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_s8q.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_s8q.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,15 +25,14 @@
 #include "arm_gemm_local.hpp"
 
 #include "depthwise_implementation.hpp"
-#include "depthwise_depthfirst_quantized.hpp"
-#include "depthwise_depthfirst_generic_quantized.hpp"
-#include "depthwise_depthfirst_multiplier_quantized.hpp"
-#include "depthwise_depthfirst_generic_multiplier_quantized.hpp"
+#include "depthwise_depthfirst.hpp"
+#include "depthwise_depthfirst_generic.hpp"
+#include "depthwise_depthfirst_multiplier.hpp"
 
 #include "depthwise_implementation_constraints.hpp"
 
 #if defined(__aarch64__)
-#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
 #include "kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp"
 #include "kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp"
 #include "kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
@@ -41,7 +40,7 @@
 #include "kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp"
 #include "kernels/sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp"
-#endif  // defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SVE)
 #include "kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp"
 #include "kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp"
 #include "kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
@@ -73,7 +72,7 @@
 
 static const DepthwiseImplementation<int8_t, int8_t, int8_t, Requantize32> depthwise_s8q_methods[] = {
 #if defined(__aarch64__)
-#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst",
@@ -84,7 +83,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst>(args, qp);
+      auto strat = new sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -96,7 +96,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst>(args, qp);
+      auto strat = new sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -108,7 +109,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -120,7 +122,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -132,7 +135,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -140,10 +144,12 @@
     "sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst",
     constraint<Requantize32>(is_supported<sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>,
                              qp_has_no_left_shift,
+                             has_channel_multiplier,
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>(args, qp);
+      auto strat = new sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<int8_t, int8_t, int8_t, int32_t, false>(strat, args, qp);
     },
   },
   {
@@ -151,13 +157,15 @@
     "sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst",
     constraint<Requantize32>(is_supported<sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>,
                              qp_has_no_left_shift,
+                             has_channel_multiplier,
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>(args, qp);
+      auto strat = new sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<int8_t, int8_t, int8_t, int32_t, false>(strat, args, qp);
     },
   },
-#endif  // defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SVE)
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst",
@@ -168,7 +176,8 @@
                              cpu_has_dot_product),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst>(args, qp);
+      auto strat = new a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -180,7 +189,8 @@
                              cpu_has_dot_product),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst>(args, qp);
+      auto strat = new a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -191,7 +201,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -202,7 +213,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -213,7 +225,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<int8_t>(strat, args, qp);
     },
   },
   {
@@ -222,7 +235,9 @@
     constraint<Requantize32>(has_no_channel_multiplier),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstGenericQuantized<a64_s8q_nhwc_generic_output9_mla_depthfirst, 3, 3>(args, qp);
+      auto kernel = new a64_s8q_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstStrategy<int8_t>(kernel, 3, 3, args);
+      return new DepthwiseDepthfirstGeneric<int8_t>(strat, args, qp);
     },
   },
   {
@@ -230,10 +245,12 @@
     "a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst",
     constraint<Requantize32>(is_supported<a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>,
                              qp_has_no_left_shift,
+                             has_channel_multiplier,
                              cpu_has_dot_product),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>(args, qp);
+      auto strat = new a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<int8_t, int8_t, int8_t, int32_t, false>(strat, args, qp);
     },
   },
   {
@@ -241,19 +258,23 @@
     "a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst",
     constraint<Requantize32>(is_supported<a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>,
                              qp_has_no_left_shift,
+                             has_channel_multiplier,
                              cpu_has_dot_product),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>(args, qp);
+      auto strat = new a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<int8_t, int8_t, int8_t, int32_t, false>(strat, args, qp);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
-    nullptr,
+    constraint<Requantize32>(has_channel_multiplier),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<int8_t, int8_t, int8_t> * {
-      return new DepthwiseDepthfirstGenericWithMultiplierQuantized<a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst>(args, qp);
+      auto kern = new a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstMultiplierStrategy<int8_t>(kern, args);
+      return new DepthwiseDepthfirstMultiplier<int8_t, int8_t, int8_t, int32_t, true>(strat, args, qp);
     },
   },
 #endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.cpp
new file mode 100644
index 0000000..33f2177
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "depthwise_strategies_common.hpp"
+
+namespace arm_conv {
+namespace depthwise {
+
+unsigned int DepthfirstStrategyUntyped::get_input_rows() const  // Input rows spanned by get_output_rows() rows of output.
+{  // NOTE(review): unsigned arithmetic; assumes get_output_rows() >= 1 - TODO confirm
+  return this->get_kernel_rows() + (this->get_output_rows() - 1) * this->get_stride_rows();  // kernel extent, plus one stride per further output row
+}
+
+unsigned int DepthfirstStrategyUntyped::get_input_cols() const  // Input columns spanned by get_output_cols() columns of output.
+{  // NOTE(review): unsigned arithmetic; assumes get_output_cols() >= 1 - TODO confirm
+  return this->get_kernel_cols() + (this->get_output_cols() - 1) * this->get_stride_cols();  // kernel extent, plus one stride per further output column
+}
+
+unsigned int DepthfirstStrategyUntyped::get_n_input_points() const { return this->get_input_rows() * this->get_input_cols(); }  // input rows * input cols
+unsigned int DepthfirstStrategyUntyped::get_n_output_points() const { return this->get_output_rows() * this->get_output_cols(); }  // output rows * output cols
+unsigned int DepthfirstStrategyUntyped::get_n_kernel_points() const { return this->get_kernel_rows() * this->get_kernel_cols(); }  // kernel rows * kernel cols
+
+unsigned int DepthfirstStrategyUntyped::get_accumulator_depth_vl() const { return 1; }  // default of one VL; declared virtual, so a strategy may override it
+
+bool DepthfirstStrategyUntyped::get_kernel_packing_point(const unsigned int index, unsigned int &x, unsigned int &y) const
+{
+  // Get the kernel point to pack at the given index; return false to indicate
+  // that this index, and every greater index, is out of range (x and y are then left unmodified).
+  if (index < (this->get_kernel_cols() * this->get_kernel_rows()))
+  {
+    y = index % this->get_kernel_cols();  // NOTE(review): y is the fast-varying coordinate here - confirm against the x/y convention of the packing code
+    x = index / this->get_kernel_cols();  // x advances once every get_kernel_cols() indices
+    return true;
+  }
+  return false;
+}
+
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.hpp
new file mode 100644
index 0000000..99b91fb
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_strategies_common.hpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+#include "interleaves/generic.hpp"
+#include "depthfirst_driver.hpp"
+
+namespace arm_conv {
+namespace depthwise {
+
+class DepthfirstStrategyUntyped : public IDepthfirstStrategy  // Geometry queries for depth-first strategies; not parameterised on element types.
+{
+  public:
+  virtual arm_gemm::VLType get_vl_type() const = 0;  // forwarded to interleaves::PackingArguments by DepthfirstStrategy
+
+  virtual unsigned int get_kernel_rows() const = 0;  // kernel extent, supplied by the concrete strategy
+  virtual unsigned int get_kernel_cols() const = 0;
+
+  virtual unsigned int get_stride_rows() const = 0;  // stride, supplied by the concrete strategy
+  virtual unsigned int get_stride_cols() const = 0;
+
+  virtual unsigned int get_input_rows() const override;  // derived: kernel + (output - 1) * stride
+  virtual unsigned int get_input_cols() const override;
+
+  virtual unsigned int get_n_input_points() const;  // rows * cols of the input extent
+  virtual unsigned int get_n_output_points() const;  // rows * cols of the output extent
+  virtual unsigned int get_n_kernel_points() const;  // rows * cols of the kernel extent
+
+  // Get the number of VLs used in the accumulator; this defaults to 1.
+  virtual unsigned int get_accumulator_depth_vl() const;
+
+  // Get the order in which to pack the weights; this defaults to a row-major
+  // sweep over the weight tensor. Returns false once `index` is out of range.
+  virtual bool get_kernel_packing_point(const unsigned int index, unsigned int &x, unsigned int &y) const;
+};
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>  // TInput and TOutput are not referenced in this class body
+class DepthfirstStrategy : public DepthfirstStrategyUntyped  // Adds type-dependent parameter packing on top of the untyped geometry queries.
+{
+  public:
+  virtual size_t get_storage_size(const DepthwiseArgs &args) const  // returns what interleaves::get_storage_size_generic computes for this strategy
+  {  // NOTE(review): declared virtual without override; not visible here whether a base class declares it
+    interleaves::PackingArguments packing_args(
+      this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+      true, sizeof(TAccum),  // NOTE(review): presumably "a bias is packed" and its element size - confirm against interleaves::PackingArguments
+      this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
+      [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool  // packing order comes from the virtual hook, so overrides are honoured
+      { return this->get_kernel_packing_point(idx, x, y); }
+    );
+    return interleaves::get_storage_size_generic(packing_args, args);
+  }
+
+  virtual void pack_parameters(  // hands biases and weights to interleaves::pack_parameters_generic, which receives `buffer` as its destination argument
+    const DepthwiseArgs &args, void *buffer,
+    const void *biases, const OutputStage &,  // the output-stage argument is unnamed and unused in this implementation
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const
+  {
+    interleaves::PackingArguments packing_args(  // same argument list as get_storage_size; presumably the two must stay in sync - TODO confirm
+      this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+      true, sizeof(TAccum),
+      this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
+      [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
+      { return this->get_kernel_packing_point(idx, x, y); }
+    );
+    interleaves::pack_parameters_generic(
+      packing_args, args, buffer, biases, weights, ld_weight_col, ld_weight_row);
+  }
+};
+
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8q.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8q.cpp
index 67713c5..b1489d0 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8q.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8q.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,23 +25,27 @@
 #include "arm_gemm_local.hpp"
 
 #include "depthwise_implementation.hpp"
-#include "depthwise_depthfirst_quantized.hpp"
-#include "depthwise_depthfirst_generic_quantized.hpp"
-#include "depthwise_depthfirst_multiplier_quantized.hpp"
-#include "depthwise_depthfirst_generic_multiplier_quantized.hpp"
+#include "depthwise_depthfirst.hpp"
+#include "depthwise_depthfirst_generic.hpp"
+#include "depthwise_depthfirst_multiplier.hpp"
 
 #include "depthwise_implementation_constraints.hpp"
 
 #if defined(__aarch64__)
-#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
 #include "kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp"
 #include "kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp"
 #include "kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp"
 #include "kernels/sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp"
-#endif  // defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SVE)
 #include "kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp"
+
+#include "kernels/a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
+#include "kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp"
+#include "kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
+
 #include "kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp"
 #include "kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
@@ -49,6 +53,7 @@
 #include "kernels/a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp"
 #include "kernels/a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp"
 #include "kernels/a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp"
+
 #endif  // defined(__aarch64__)
 
 #include <cstdint>
@@ -60,7 +65,7 @@
 
 static const DepthwiseImplementation<uint8_t, uint8_t, uint8_t, Requantize32> depthwise_u8q_methods[] = {
 #if defined(__aarch64__)
-#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst",
@@ -70,7 +75,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst>(args, qp);
+      auto strat = new sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -82,7 +88,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -94,7 +101,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -106,7 +114,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -114,10 +123,12 @@
     "sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst",
     constraint<Requantize32>(is_supported<sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>,
                              qp_has_no_left_shift,
+                             has_channel_multiplier,
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>(args, qp);
+      auto strat = new sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<uint8_t, uint8_t, uint8_t, int32_t, false>(strat, args, qp);
     },
   },
   {
@@ -125,13 +136,15 @@
     "sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst",
     constraint<Requantize32>(is_supported<sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>,
                              qp_has_no_left_shift,
+                             has_channel_multiplier,
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>(args, qp);
+      auto strat = new sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<uint8_t, uint8_t, uint8_t, int32_t, false>(strat, args, qp);
     },
   },
-#endif  // defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SVE)
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst",
@@ -141,9 +154,51 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst>(args, qp);
+      auto strat = new a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
+
+  {
+    DepthwiseMethod::DEPTHFIRST,
+    "a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst",
+    constraint<Requantize32>(is_supported<a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst>,
+                             has_no_channel_multiplier,
+                             qp_zero_a_offset,
+                             qp_has_no_left_shift),
+    nullptr,
+    [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
+      auto strat = new a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
+    },
+  },
+  {
+    DepthwiseMethod::DEPTHFIRST,
+    "a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst",
+    constraint<Requantize32>(is_supported<a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst>,
+                             has_no_channel_multiplier,
+                             qp_zero_a_offset,
+                             qp_has_no_left_shift),
+    nullptr,
+    [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
+      auto strat = new a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
+    },
+  },
+  {
+    DepthwiseMethod::DEPTHFIRST,
+    "a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst",
+    constraint<Requantize32>(is_supported<a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst>,
+                             has_no_channel_multiplier,
+                             qp_zero_a_offset,
+                             qp_has_no_left_shift),
+    nullptr,
+    [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
+      auto strat = new a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
+    },
+  },
+
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst",
@@ -152,7 +207,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -163,7 +219,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -174,7 +231,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -183,7 +241,9 @@
     constraint<Requantize32>(has_no_channel_multiplier),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstGenericQuantized<a64_u8q_nhwc_generic_output9_mla_depthfirst, 3, 3>(args, qp);
+      auto kernel = new a64_u8q_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstStrategy<uint8_t>(kernel, 3, 3, args);
+      return new DepthwiseDepthfirstGeneric<uint8_t>(strat, args, qp);
     },
   },
   {
@@ -191,10 +251,12 @@
     "a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst",
     constraint<Requantize32>(is_supported<a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>,
                              cpu_has_dot_product,
+                             has_channel_multiplier,
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst>(args, qp);
+      auto strat = new a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<uint8_t, uint8_t, uint8_t, int32_t, false>(strat, args, qp);
     },
   },
   {
@@ -202,21 +264,26 @@
     "a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst",
     constraint<Requantize32>(is_supported<a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>,
                              cpu_has_dot_product,
+                             has_channel_multiplier,
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstWithMultiplierQuantized<a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst>(args, qp);
+      auto strat = new a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirstMultiplier<uint8_t, uint8_t, uint8_t, int32_t, false>(strat, args, qp);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
-    nullptr,
+    constraint<Requantize32>(has_channel_multiplier),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, uint8_t, uint8_t> * {
-      return new DepthwiseDepthfirstGenericWithMultiplierQuantized<a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst>(args, qp);
+      auto kern = new a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstMultiplierStrategy<uint8_t>(kern, args);
+      return new DepthwiseDepthfirstMultiplier<uint8_t, uint8_t, uint8_t, int32_t, true>(strat, args, qp);
     },
   },
+
 #endif  // defined(__aarch64__)
   { DepthwiseMethod::DEFAULT, "", nullptr, nullptr, nullptr },  // End of list
 };
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8s8u8q.cpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8s8u8q.cpp
index af4426b..9b98901 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8s8u8q.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_u8s8u8q.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,19 +25,18 @@
 #include "arm_gemm_local.hpp"
 
 #include "depthwise_implementation.hpp"
-#include "depthwise_depthfirst_quantized.hpp"
-#include "depthwise_depthfirst_generic_quantized.hpp"
-#include "depthwise_depthfirst_multiplier_quantized.hpp"
-#include "depthwise_depthfirst_generic_multiplier_quantized.hpp"
+#include "depthwise_depthfirst.hpp"
+#include "depthwise_depthfirst_generic.hpp"
+#include "depthwise_depthfirst_multiplier.hpp"
 
 #include "depthwise_implementation_constraints.hpp"
 
 #if defined(__aarch64__)
-#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
 #include "kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp"
 #include "kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
-#endif  // defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SVE)
 #include "kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp"
 #include "kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp"
 #include "kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp"
@@ -54,7 +53,7 @@
 
 static const DepthwiseImplementation<uint8_t, int8_t, uint8_t, Requantize32> depthwise_u8q_methods[] = {
 #if defined(__aarch64__)
-#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
   {
     DepthwiseMethod::DEPTHFIRST,
     "sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst",
@@ -64,7 +63,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t, int8_t>(strat, args, qp);
     },
   },
   {
@@ -76,7 +76,8 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t, int8_t>(strat, args, qp);
     },
   },
   {
@@ -88,10 +89,11 @@
                              cpu_has_sve2),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t, int8_t>(strat, args, qp);
     },
   },
-#endif  // defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(ARM_COMPUTE_ENABLE_SVE)
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst",
@@ -100,7 +102,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t, int8_t>(strat, args, qp);
     },
   },
   {
@@ -111,7 +114,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t, int8_t>(strat, args, qp);
     },
   },
   {
@@ -122,7 +126,8 @@
                              qp_has_no_left_shift),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstQuantized<a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst>(args, qp);
+      auto strat = new a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(args.cpu_info);
+      return new DepthwiseDepthfirst<uint8_t, int8_t>(strat, args, qp);
     },
   },
   {
@@ -131,16 +136,20 @@
     constraint<Requantize32>(has_no_channel_multiplier),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstGenericQuantized<a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst, 3, 3>(args, qp);
+      auto kernel = new a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstStrategy<uint8_t, int8_t>(kernel, 3, 3, args);
+      return new DepthwiseDepthfirstGeneric<uint8_t, int8_t>(strat, args, qp);
     },
   },
   {
     DepthwiseMethod::DEPTHFIRST,
     "a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst",
-    nullptr,
+    constraint<Requantize32>(has_channel_multiplier),
     nullptr,
     [] (const DepthwiseArgs &args, const Requantize32 &qp) -> DepthwiseCommon<uint8_t, int8_t, uint8_t> * {
-      return new DepthwiseDepthfirstGenericWithMultiplierQuantized<a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst>(args, qp);
+      auto kern = new a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(args.cpu_info);
+      auto strat = new GenericDepthfirstMultiplierStrategy<uint8_t, int8_t>(kern, args);
+      return new DepthwiseDepthfirstMultiplier<uint8_t, int8_t, uint8_t, int32_t, true>(strat, args, qp);
     },
   },
 #endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
new file mode 100644
index 0000000..056f08d
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
@@ -0,0 +1,150 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "generic.hpp"
+
+#include <functional>
+
+namespace arm_conv {
+namespace depthwise {
+namespace interleaves {
+
+// Stores the supplied kernel shape, element sizes, vector-length settings and
+// weight-position callback verbatim in the corresponding members.
+PackingArguments::PackingArguments(
+  unsigned int kernel_rows, unsigned int kernel_cols, size_t weight_element_size,
+  bool include_bias, size_t bias_element_size,
+  arm_gemm::VLType vl_type, size_t accumulator_element_size, unsigned int accumulator_depth_vl,
+  std::function<bool(unsigned int, unsigned int &, unsigned int &)> get_weight_pos
+) : kernel_rows(kernel_rows), kernel_cols(kernel_cols), weight_element_size(weight_element_size),
+    include_bias(include_bias), bias_element_size(bias_element_size), vl_type(vl_type),
+    accumulator_element_size(accumulator_element_size), accumulator_depth_vl(accumulator_depth_vl),
+    get_weight_pos(get_weight_pos) {}
+
+// Returns the number of bytes needed for the packed parameters of `args`.
+// Channels are handled in blocks of `vl`; every channel of a block carries an
+// optional bias element plus `kernel_points()` weight elements.
+size_t get_storage_size_generic(const PackingArguments &packing_args, const DepthwiseArgs &args)
+{
+  if (args.channel_multiplier > 1)
+  {
+    // A channel multiplier is handled as `input_channels` independent
+    // packings, each spanning `channel_multiplier` channels.
+    DepthwiseArgs single_channel_args(args);
+    single_channel_args.input_channels = args.channel_multiplier;
+    single_channel_args.channel_multiplier = 1;
+    return args.input_channels * get_storage_size_generic(packing_args, single_channel_args);
+  }
+
+  const unsigned int vl = packing_args.accumulator_depth_vl *
+    arm_gemm::utils::get_vector_length<uint8_t>(packing_args.vl_type) / packing_args.accumulator_element_size;
+  const auto bias_bytes = packing_args.include_bias ? packing_args.bias_element_size : 0;
+  const auto weight_bytes = packing_args.kernel_points() * packing_args.weight_element_size;
+  return arm_gemm::iceildiv(args.input_channels, vl) * (bias_bytes + weight_bytes) * vl;
+}
+
+// Packs biases and weights into `buffer_raw` using the layout sized by
+// `get_storage_size_generic`: per block of `vl` channels, an optional bias
+// vector followed by one weight vector per `get_weight_pos` position.  Zero
+// strides select the dense defaults; lanes past the last real channel of a
+// block are zero-filled rather than left uninitialised.
+void pack_parameters_generic(
+  const PackingArguments &packing_args,
+  const DepthwiseArgs &args,
+  void *buffer_raw,
+  const void *biases_raw,
+  const void *weights_raw,
+  size_t ld_weight_col,
+  size_t ld_weight_row
+)
+{
+  // View all three buffers as raw bytes; element sizes come from `packing_args`
+  auto *buffer = static_cast<uint8_t *>(buffer_raw);
+  auto *biases = static_cast<const uint8_t *>(biases_raw);
+  auto *weights = static_cast<const uint8_t *>(weights_raw);
+
+  // If the channel multiplier is greater than one, then we treat this as a
+  // repeated packing of `channel_multiplier`-sized problems.
+  if (args.channel_multiplier > 1)
+  {
+    DepthwiseArgs args_per_input_channel(args);
+    args_per_input_channel.input_channels = args.channel_multiplier;
+    args_per_input_channel.channel_multiplier = 1;
+    // Resolve the strides here, while the full channel count is still known
+    ld_weight_col = ld_weight_col ? ld_weight_col : args.input_channels * args.channel_multiplier;
+    ld_weight_row = ld_weight_row ? ld_weight_row : ld_weight_col * packing_args.kernel_cols;
+
+    const auto per_input_channel_size = get_storage_size_generic(packing_args, args_per_input_channel);
+    for (unsigned int c = 0; c < args.input_channels; c++)
+    {
+      pack_parameters_generic(
+        packing_args, args_per_input_channel, buffer, biases, weights, ld_weight_col, ld_weight_row);
+      // Step every pointer on to the next input channel
+      buffer += per_input_channel_size;
+      biases += (biases == nullptr) ? 0 : packing_args.bias_element_size * args.channel_multiplier;
+      weights += packing_args.weight_element_size * args.channel_multiplier;
+    }
+    return;
+  }
+
+  // Finalise the weight strides
+  ld_weight_col = (ld_weight_col == 0) ? args.input_channels : ld_weight_col;
+  ld_weight_row = (ld_weight_row == 0) ? packing_args.kernel_cols * ld_weight_col : ld_weight_row;
+
+  const unsigned int vl =
+    packing_args.accumulator_depth_vl *
+    arm_gemm::utils::get_vector_length<uint8_t>(packing_args.vl_type) / packing_args.accumulator_element_size;
+
+  for (unsigned int n = 0; n < args.input_channels; n += vl)
+  {
+    const unsigned int todo = std::min(vl, args.input_channels - n);
+
+    if (packing_args.include_bias)
+    {
+      // Copy the real biases (if any), then zero the rest of the vector
+      const unsigned int n_bias = (biases != nullptr) ? todo : 0;
+      if (n_bias != 0)
+      {
+        memcpy(buffer, biases, n_bias * packing_args.bias_element_size);
+        biases += n_bias * packing_args.bias_element_size;
+      }
+      memset(buffer + n_bias * packing_args.bias_element_size, 0, (vl - n_bias) * packing_args.bias_element_size);
+      buffer += vl * packing_args.bias_element_size;
+    }
+
+    // Copy each of the weights in turn, zero-padding a short final block
+    unsigned int kx, ky;
+    for (unsigned int kindex = 0; packing_args.get_weight_pos(kindex, kx, ky); kindex++)
+    {
+      const auto src_ptr = weights + (kx*ld_weight_row + ky*ld_weight_col + n) * packing_args.weight_element_size;
+      memcpy(buffer, src_ptr, todo * packing_args.weight_element_size);
+      memset(buffer + todo * packing_args.weight_element_size, 0, (vl - todo) * packing_args.weight_element_size);
+      buffer += vl * packing_args.weight_element_size;
+    }
+  }
+}
+
+}  // namespace interleaves
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.hpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.hpp
new file mode 100644
index 0000000..5b5ae17
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.hpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+#include "depthwise.hpp"
+
+#include <functional>
+
+namespace arm_conv {
+namespace depthwise {
+namespace interleaves {
+
+// Describes how parameters are to be packed: kernel shape, element sizes,
+// vector-length configuration and a callback enumerating weight positions.
+struct PackingArguments
+{
+  const unsigned int kernel_rows;
+  const unsigned int kernel_cols;
+  const size_t weight_element_size;
+  const bool include_bias;  // Whether each block of channels carries a bias vector
+  const size_t bias_element_size;
+  arm_gemm::VLType vl_type;
+  const size_t accumulator_element_size;
+  const unsigned int accumulator_depth_vl;
+  // Called with increasing indices from zero; writes the weight position for
+  // that index through its reference arguments and returns false to stop.
+  std::function<bool(unsigned int, unsigned int &, unsigned int &)> get_weight_pos;
+
+  // Number of points in the kernel.
+  unsigned int kernel_points(void) const { return kernel_cols * kernel_rows; }
+
+  PackingArguments(
+    unsigned int kernel_rows, unsigned int kernel_cols, size_t weight_element_size,
+    bool include_bias, size_t bias_element_size,
+    arm_gemm::VLType vl_type, size_t accumulator_element_size, unsigned int accumulator_depth_vl,
+    std::function<bool(unsigned int, unsigned int &, unsigned int &)> get_weight_pos
+  );
+};
+
+// Size in bytes of the buffer filled by `pack_parameters_generic`.
+size_t get_storage_size_generic(
+  const PackingArguments &packing_args, const DepthwiseArgs &args
+);
+
+// Packs biases and weights into `buffer_raw`; a null `biases_raw` is accepted
+// and zero weight strides select the dense default.
+void pack_parameters_generic(
+  const PackingArguments &packing_args,
+  const DepthwiseArgs &args,
+  void *buffer_raw,
+  const void *biases_raw, const void *weights_raw,
+  size_t ld_weight_col, size_t ld_weight_row
+);
+
+}  // namespace interleaves
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.cpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.cpp
new file mode 100644
index 0000000..a638905
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.cpp
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "generic_quantized_dot_product.hpp"
+#include <cstdint>
+
+namespace arm_conv {
+namespace depthwise {
+namespace interleaves {
+namespace quantized {
+
+// Returns the size in bytes of the buffer written by `pack_parameters`.  Every
+// iteration covers `iter_length` output channels and stores, per channel, one
+// int32 bias, the dot-product-padded weights and two int32 requantisation
+// parameters.
+size_t get_storage_size(
+  const DepthwiseArgs &args,
+  const arm_gemm::VLType vl_type,
+  const unsigned int accumulator_depth_vl
+)
+{
+  const unsigned int iter_length = accumulator_depth_vl * arm_gemm::utils::get_vector_length<int32_t>(vl_type);
+  const unsigned int n_iters = args.input_channels * arm_gemm::iceildiv(args.channel_multiplier, iter_length);
+
+  // Each kernel row is rounded up to a whole number of four-element dots.
+  const unsigned int n_dots_per_kernel_row = arm_gemm::iceildiv(args.kernel_cols, 4u);
+
+  const size_t bias_bytes = sizeof(int32_t);
+  const size_t weight_bytes = 4 * n_dots_per_kernel_row * args.kernel_rows * sizeof(int8_t);
+  const size_t requant_bytes = 2 * sizeof(int32_t);
+  return n_iters * iter_length * (bias_bytes + weight_bytes + requant_bytes);
+}
+
+// Packs biases, weights and requantisation parameters into `_buffer` in the
+// layout sized by `get_storage_size`.  Each frame of `iter_length` channels
+// holds int32 biases (with the quantisation offsets folded in), weights
+// grouped four-per-dot with zero padding, then int32 multipliers and shifts.
+template <typename T>
+void pack_parameters(
+  void *_buffer, const int32_t *biases,
+  const T *weights, size_t ld_weight_col, size_t ld_weight_row,
+  const DepthwiseArgs &args,
+  const arm_gemm::Requantize32 &qp,
+  const arm_gemm::VLType vl_type,
+  const unsigned int accumulator_depth_vl
+)
+{
+  auto buffer = static_cast<uint8_t *>(_buffer);
+  auto requant_muls = qp.per_channel_muls;
+  auto requant_shifts = qp.per_channel_right_shifts;
+
+  const unsigned int iter_length = accumulator_depth_vl * arm_gemm::utils::get_vector_length<int32_t>(vl_type);
+  const unsigned int n_iters_per_input_channel = arm_gemm::iceildiv(args.channel_multiplier, iter_length);
+  const unsigned int n_dots_per_kernel_row = arm_gemm::iceildiv(args.kernel_cols, 4u);
+
+  const size_t iter_stride = iter_length * (
+      sizeof(int32_t) +  // Bias
+      4 * n_dots_per_kernel_row * args.kernel_rows * sizeof(T) +  // Weights
+      2 * sizeof(int32_t)  // Requantisation parameters
+  );
+
+  ld_weight_col = (ld_weight_col == 0) ? args.input_channels * args.channel_multiplier : ld_weight_col;
+  ld_weight_row = (ld_weight_row == 0) ? args.kernel_cols * ld_weight_col : ld_weight_row;
+
+  for (unsigned int input_channel = 0; input_channel < args.input_channels; input_channel++)
+  {
+    auto buffer_input_channel = buffer + input_channel * n_iters_per_input_channel * iter_stride;
+    auto weights_input_channel = weights + input_channel * args.channel_multiplier;
+
+    for (unsigned int iter = 0; iter < n_iters_per_input_channel; iter++)
+    {
+      // Get a pointer to the start of this portion of the buffer; consequently
+      // derive pointers to the bias, weight and requantisation portions of
+      // this frame.
+      auto buffer_base = buffer_input_channel + iter_stride * iter;
+      auto buffer_biases = reinterpret_cast<int32_t *>(buffer_base);
+      auto buffer_weights = buffer_base + sizeof(int32_t) * iter_length;
+      auto buffer_requant_mul = reinterpret_cast<int32_t *>(
+        buffer_weights + args.kernel_rows * n_dots_per_kernel_row * 4 * iter_length);
+      auto buffer_requant_shift = buffer_requant_mul + iter_length;
+      auto weights_base = weights_input_channel + iter * iter_length;
+
+      // Hence work through the data for this iteration, on a
+      // channel-by-channel basis.
+      const auto this_iter_length = std::min<unsigned int>(iter_length, args.channel_multiplier - iter * iter_length);
+      for (unsigned int i = 0; i < this_iter_length; i++)
+      {
+        auto weights_channel = weights_base + i;
+
+        // Read the bias value, we modify this as we read the weights.
+        auto bias_value = biases == nullptr ? 0 : *(biases++);
+        int32_t elements_sum = 0;
+
+        // Read through the kernel; for each row, marshal together as many dot
+        // product terms as are required.
+        for (unsigned int ki = 0; ki < args.kernel_rows; ki++)
+        {
+          auto buffer_row = buffer_weights + i*4 + ki * 4 * n_dots_per_kernel_row * iter_length;
+          auto weights_row = weights_channel + ki * ld_weight_row;
+
+          unsigned int kj = 0;
+          for (; kj < args.kernel_cols; kj++)
+          {
+            // Determine which dot, and which element within it, we are writing
+            const auto dot = kj / 4;
+            const auto elem = kj % 4;
+
+            // Copy the value; include in the sum
+            const auto val = weights_row[kj * ld_weight_col];
+            buffer_row[dot * 4 * iter_length + elem] = val;
+            elements_sum += val;
+          }
+          for (; kj < 4 * n_dots_per_kernel_row; kj++)  // Zero-pad the final dot of the row
+          {
+            const auto dot = kj / 4;
+            const auto elem = kj % 4;
+            buffer_row[dot * 4 * iter_length + elem] = 0;
+          }
+        }
+
+        // Write back the bias and offset values
+        *(buffer_biases++) =
+          bias_value - qp.a_offset * elements_sum +
+          args.kernel_rows * args.kernel_cols * qp.a_offset * qp.b_offset;
+
+        // Write out the requantisation parameters
+        *(buffer_requant_mul++) = qp.per_channel_requant ? *(requant_muls++) : qp.per_layer_mul;
+        *(buffer_requant_shift++) = qp.per_channel_requant ? *(requant_shifts++) : qp.per_layer_right_shift;
+      }
+    }
+  }
+}
+
+template void pack_parameters(void *, const int32_t *, const int8_t *, size_t, size_t, const DepthwiseArgs &, const arm_gemm::Requantize32 &, arm_gemm::VLType, unsigned int);  // Explicit instantiation for int8_t weights
+template void pack_parameters(void *, const int32_t *, const uint8_t *, size_t, size_t, const DepthwiseArgs &, const arm_gemm::Requantize32 &, arm_gemm::VLType, unsigned int);  // Explicit instantiation for uint8_t weights
+
+}  // namespace quantized
+}  // namespace interleaves
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.hpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.hpp
new file mode 100644
index 0000000..779d67d
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.hpp
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "generic.hpp"
+
+namespace arm_conv {
+namespace depthwise {
+namespace interleaves {
+namespace quantized {
+
+// Size in bytes of the buffer filled by `pack_parameters`.
+size_t get_storage_size(
+  const DepthwiseArgs &args, arm_gemm::VLType vl_type,
+  unsigned int accumulator_depth_vl=1
+);
+
+// Packs biases (may be null), weights and requantisation parameters from `qp`
+// into `buffer`; zero weight strides select the dense default layout.
+template <typename T>
+void pack_parameters(
+  void *buffer, const int32_t *biases,
+  const T *weights, size_t ld_weight_col, size_t ld_weight_row,
+  const DepthwiseArgs &args, const arm_gemm::Requantize32 &qp,
+  arm_gemm::VLType vl_type, unsigned int accumulator_depth_vl
+);
+
+}  // namespace quantized
+}  // namespace interleaves
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/list.hpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/list.hpp
index cb49a24..76f38eb 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/list.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/list.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,90 +29,30 @@
 
 #if defined(ARM_COMPUTE_ENABLE_SVE)
 
-class interleave_sve_u8q_3x3_dot
+struct interleave_sve_u8q_3x3_dot
 {
-  public:
-    static void pack_parameters(unsigned int, void *, const int32_t *, const uint8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
+  static void pack_parameters(unsigned int, void *, const int32_t *, const uint8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
+  static size_t get_packed_size(const DepthwiseArgs &);
 };
 
-class interleave_sve_s8q_3x3_dot
+struct interleave_sve_s8q_3x3_dot
 {
-  public:
-    static void pack_parameters(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_sve_u8q_3x3_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const uint8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_sve_s8q_3x3_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const int8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_sve_u8q_5x5_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const uint8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_sve_s8q_5x5_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const int8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
+  static void pack_parameters(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
+  static size_t get_packed_size(const DepthwiseArgs &);
 };
 
 #endif  // defined(ARM_COMPUTE_ENABLE_SVE)
 
-class interleave_a64_u8q_3x3_dot
+struct interleave_a64_u8q_3x3_dot
 {
-  public:
-    static void pack_parameters(unsigned int, void *, const int32_t *, const uint8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
+  static void pack_parameters(unsigned int, void *, const int32_t *, const uint8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
+  static size_t get_packed_size(const DepthwiseArgs &);
 };
 
-class interleave_a64_s8q_3x3_dot
+struct interleave_a64_s8q_3x3_dot
 {
-  public:
-    static void pack_parameters(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_a64_u8q_3x3_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const uint8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_a64_s8q_3x3_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const int8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_a64_u8q_5x5_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const uint8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
-};
-
-class interleave_a64_s8q_5x5_mla
-{
-  public:
-    static void pack_parameters(unsigned int, void *, const int8_t *, size_t, size_t);
-    static size_t get_packed_size(const DepthwiseArgs &);
+  static void pack_parameters(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
+  static size_t get_packed_size(const DepthwiseArgs &);
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index be50d1c..d2db125 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,19 +36,16 @@
 void a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,60 +56,13 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
-
-  a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  a64_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>(2, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
index 39fa7f6..75368df 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,19 +36,16 @@
 void a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,60 +56,13 @@
   constexpr static unsigned int output_rows = 3;
   constexpr static unsigned int output_cols = 3;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *) {}
+  a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unnamed and unused here
+  : Parent(3, 3, 1) {}  // NOTE(review): presumably (output size, kernel size, stride) - confirm against Parent
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst/generic_indirect.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst/generic_indirect.cpp
index e0abca9..faf6c91 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst/generic_indirect.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst/generic_indirect.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -1283,7 +1283,8 @@
 
     :
     : [n_channels] "r" ((unsigned long) n_channels), [offsetof_Args_inptrs] "I" (offsetof(Args, inptrs)), [offsetof_args_max] "I" (offsetof(Args, max)), [offsetof_args_min] "I" (offsetof(Args, min)), [offsetof_args_outptrs] "I" (offsetof(Args, outptrs)), [offsetof_args_params] "I" (offsetof(Args, params)), [params_struct] "r" (&params_struct)
-    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v16", "v17", "v18", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" );
+    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v16", "v17", "v18", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+  );
 }
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
index 1e0d922..4f0de6b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,19 +36,16 @@
 void a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,60 +56,13 @@
   constexpr static unsigned int output_rows = 4;
   constexpr static unsigned int output_cols = 4;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *) {}
+  a64_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unnamed and unused here
+  : Parent(4, 3, 1) {}  // NOTE(review): presumably (output size, kernel size, stride) - confirm against Parent
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index d89ae0c..d52f480 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,19 +36,16 @@
 void a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,60 +56,13 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  a64_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unnamed and unused here
+  : Parent(2, 3, 2) {}  // NOTE(review): presumably (output size, kernel size, stride) - confirm against Parent
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index 6b5f91f..81a608e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,19 +36,16 @@
 void a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
@@ -59,60 +56,13 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  a64_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unnamed and unused here
+  : Parent(2, 5, 1) {}  // NOTE(review): presumably (output size, kernel size, stride) - confirm against Parent
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst.hpp
index 3468b70..1ccd340 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,32 +28,24 @@
 
 #pragma once
 
-#if defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 
 namespace arm_conv {
 namespace depthwise {
 
 void a64_fp16_nhwc_generic_output9_mla_depthfirst_impl(const __fp16 *const *const, __fp16 *const *const, const void *, const void *, const unsigned int, const unsigned int, const __fp16, const __fp16);
 
-struct a64_fp16_nhwc_generic_output9_mla_depthfirst
+class a64_fp16_nhwc_generic_output9_mla_depthfirst : public GenericDepthfirstKernelStrategy<__fp16, __fp16, __fp16, __fp16>
 {
-  typedef __fp16 bias_type;
-  typedef __fp16 input_type;
-  typedef __fp16 weight_type;
-  typedef __fp16 return_type;
+  KernelType kernel = a64_fp16_nhwc_generic_output9_mla_depthfirst_impl;
 
-  typedef void (*kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, const void *, const unsigned int, const unsigned int, const __fp16, const __fp16);
+  public:
+  a64_fp16_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) : GenericDepthfirstKernelStrategy<__fp16, __fp16, __fp16, __fp16>(9, arm_gemm::VLType::None) {}
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int n_output_points = 9;
-
-  kern_type kernel = a64_fp16_nhwc_generic_output9_mla_depthfirst_impl;
-
-  a64_fp16_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) {}
+  KernelType get_kernel() const override { return kernel; }  // Hands back the function pointer stored in `kernel`.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif  // defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst/generic.cpp
index 8ac79f8..423ee41 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_nhwc_generic_output9_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,7 +25,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 
 namespace arm_conv {
 namespace depthwise {
@@ -524,4 +524,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif  // defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
index a02a2b2..8fcbce2 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,35 +28,25 @@
 
 #pragma once
 
-#if defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 
 namespace arm_conv {
 namespace depthwise {
 
 void a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl(const __fp16 *const *const, __fp16 *const *const, const __fp16 *, const __fp16 *, const unsigned int, const unsigned int, const __fp16, const __fp16);
 
-struct a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst
+struct a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst : GenericDepthfirstMultiplierKernelStrategy<__fp16, __fp16, __fp16, __fp16>
 {
-  typedef __fp16 bias_type;
-  typedef __fp16 input_type;
-  typedef __fp16 weight_type;
-  typedef __fp16 return_type;
-
-  typedef void (*kern_type)(const __fp16 *const *const, __fp16 *const *const, const __fp16 *, const __fp16 *, const unsigned int, const unsigned int, const __fp16, const __fp16);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int output_rows(void) { return 2; };
-  constexpr static unsigned int output_cols(void) { return 8; };
-
-  constexpr static unsigned int output_col_regs(void) { return 1; };
-
-  kern_type kernel = a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
-
-  a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *) {}
+  using Parent = GenericDepthfirstMultiplierKernelStrategy<__fp16, __fp16, __fp16, __fp16>;
+  a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *)
+  : Parent(2, 8, arm_gemm::VLType::None)
+  {
+  }
+  Parent::KernelType kernel = a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif  // defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst/generic.cpp
index 7ed7c52..d9fc140 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp16_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,7 +25,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 
 namespace arm_conv {
 namespace depthwise {
@@ -1046,4 +1046,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif  // defined(__aarch64__) && defined(__ARM_FP16_ARGS) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index a888eb5..420e953 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__
+#if defined(__aarch64__)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
-
-  a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unnamed and unused here
+  : Parent(2, 3, 1) {}  // NOTE(review): presumably (output size, kernel size, stride) - confirm against Parent
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
index 01bb06a..0e9a3ba 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__
+#if defined(__aarch64__)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 3;
   constexpr static unsigned int output_cols = 3;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *) {}
+  a64_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<float, float, float, float>(3, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
index 17084b5..6c897d6 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__
+#if defined(__aarch64__)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 4;
   constexpr static unsigned int output_cols = 4;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *) {}
+  a64_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<float, float, float, float>(4, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index f23862b..ff521fb 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__
+#if defined(__aarch64__)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  a64_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<float, float, float, float>(2, 3, 2) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index e4bfbe6..c88a7d5 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,19 +36,16 @@
 void a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::None;
 
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
@@ -59,60 +56,13 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<float, float, float, float>(2, 5, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_generic_output9_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_generic_output9_mla_depthfirst.hpp
index 0f6cecd..6fa02b7 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_generic_output9_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_generic_output9_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,28 +28,24 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_fp32_nhwc_generic_output9_mla_depthfirst_impl(const float *const *const, float *const *const, const void *, const void *, const unsigned int, const unsigned int, const float, const float);
 
-struct a64_fp32_nhwc_generic_output9_mla_depthfirst
+class a64_fp32_nhwc_generic_output9_mla_depthfirst : public GenericDepthfirstKernelStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
+  KernelType kernel = a64_fp32_nhwc_generic_output9_mla_depthfirst_impl;
 
-  typedef void (*kern_type)(const float *const *const, float *const *const, const void *, const void *, const unsigned int, const unsigned int, const float, const float);
+  public:
+  a64_fp32_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) : GenericDepthfirstKernelStrategy<float, float, float, float>(9, arm_gemm::VLType::None) {}
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int n_output_points = 9;
-
-  kern_type kernel = a64_fp32_nhwc_generic_output9_mla_depthfirst_impl;
-
-  a64_fp32_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) {}
+  KernelType get_kernel() const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp
index 60f5ddd..2ec0525 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,39 +28,34 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst_impl(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
 
-struct a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst
+struct a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst : DepthfirstMultiplierStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
-
-  typedef void (*kern_type)(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
+  using Parent = DepthfirstMultiplierStrategy<float, float, float, float>;
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 3;
-  constexpr static unsigned int output_cols = 3;
+  a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(const CPUInfo *)
+  : Parent(3, 3, kernel_rows, kernel_cols, stride_rows, stride_cols)
+  {
+  }
 
-  constexpr static unsigned int input_rows = 7;
-  constexpr static unsigned int input_cols = 7;
-  constexpr static unsigned int input_col_quads = 2;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::None; }
 
-  kern_type kernel = a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst_impl;
-
-  a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp
index 92d6a75..5ae8dd3 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,39 +28,34 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst_impl(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
 
-struct a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst
+struct a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst : DepthfirstMultiplierStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
-
-  typedef void (*kern_type)(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
+  using Parent = DepthfirstMultiplierStrategy<float, float, float, float>;
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 4;
+  a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(const CPUInfo *)
+  : Parent(2, 4, kernel_rows, kernel_cols, stride_rows, stride_cols)
+  {
+  }
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 8;
-  constexpr static unsigned int input_col_quads = 2;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::None; }
 
-  kern_type kernel = a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst_impl;
-
-  a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
index 2cc2f7c..d60e15e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,31 +28,25 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl(const float *const *const, float *const *const, const float *, const float *, const unsigned int, const unsigned int, const float, const float);
 
-struct a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst
+struct a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst : GenericDepthfirstMultiplierKernelStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
-
-  typedef void (*kern_type)(const float *const *const, float *const *const, const float *, const float *, const unsigned int, const unsigned int, const float, const float);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int output_rows(void) { return 2; };
-  constexpr static unsigned int output_cols(void) { return 8; };
-
-  constexpr static unsigned int output_col_regs(void) { return 2; };
-
-  kern_type kernel = a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
-
-  a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *) {}
+  using Parent = GenericDepthfirstMultiplierKernelStrategy<float, float, float, float>;
+  a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *)
+  : Parent(2, 8, arm_gemm::VLType::None)
+  {
+  }
+  Parent::KernelType kernel = a64_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
index c76cb99..62e4a82 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,39 +34,40 @@
 namespace arm_conv {
 namespace depthwise {
 
-void a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
+void a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32&, const int32_t *, const int32_t *, int8_t *const *);
 
-struct a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst
+class a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_3x3_dot::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_3x3_dot::get_packed_size;
+  Parent::KernelType kernel = a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleave_a64_s8q_3x3_dot::get_packed_size(args);
+  }
 
-  kern_type kernel = a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
-
-  a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) {}
+  void pack_parameters(
+    const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleave_a64_s8q_3x3_dot::pack_parameters(
+      args.input_channels, buffer, reinterpret_cast<const int32_t *>(biases),
+      reinterpret_cast<const int8_t *>(weights), qp, ld_weight_col, ld_weight_row
+    );
+  }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
index ed8cd48..f8245fc 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,15 @@
 namespace arm_conv {
 namespace depthwise {
 
-void a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *const inptrs, int8_t *const *const outptrs, const void *params, const uint64_t n_channels, const arm_gemm::Requantize32& qp)
+void a64_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(
+  const unsigned int n_channels,
+  const int8_t *const *const inptrs,
+  const int8_t *params,
+  const int32_t *,  // Bias, should be wrapped into the parameters
+  const arm_gemm::Requantize32& qp,
+  const int32_t *, const int32_t *,  // Requant parameters, also wrapped
+  int8_t *const *const outptrs
+)
 {
   __asm__ __volatile__(
     "ldp x13, x12, [%x[inptrs], #0x0]\n"
@@ -1307,7 +1315,7 @@
     "34:"  // End
     "add SP, SP, #0x80\n"
     : [params] "+&r" (params)
-    : [inptrs] "r" (inptrs), [n_channels] "r" (n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
+    : [inptrs] "r" (inptrs), [n_channels] "r" ((long unsigned int) n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
   );
 }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index 76c927a..c1baab4 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
 
-struct a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst
+class a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_3x3_mla::get_packed_size;
-
-  kern_type kernel = a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
-
-  a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
index 3001276..0e8d16f 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const int8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const int8_t *const *inptrs_raw,
-      const int8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -91,513 +91,497 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
     "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x17, #0x0\n"
-    "ldr x16, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x24, x19, %[offsetof_Requantize32_a_offset]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x22, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x21, x19, %[offsetof_Requantize32_c_offset]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v22.16b }, [x24]\n"
+    "ld1r { v12.16b }, [x23]\n"
+    "lsr x16, x8, #0x3\n"
+    "ld1r { v14.8h }, [x21]\n"
+    "ld1r { v17.8h }, [x20]\n"
     "mov x15, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x14, %x[params], %[offsetof_Params_inptrs]\n"
+    "mov x14, #0x0\n"
+    "ld1r { v15.8h }, [x19]\n"
     "ldr x13, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x12, x8, #0x3\n"
+    "add x12, %x[params], %[offsetof_Params_inptrs]\n"
     "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x19, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v14.16b }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v9.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v15.4s }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v24.4s }, [x20]\n"
-    "ld1r { v12.4s }, [x19]\n"
-    "ldp x10, x9, [x21, #0x0]\n"
-    "ldp x28, x27, [x21, #0x10]\n"
-    "cbz x12, 3f\n"
-    "subs x12, x12, #0x1\n"
+    "ldp x10, x9, [x22, #0x0]\n"
+    "ldp x28, x27, [x22, #0x10]\n"
+    "cbz x16, 3f\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q13, [x19, #0x0]\n"
-    "mov v17.16b, v13.16b\n"
-    "ldr q19, [x19, #0x10]\n"
+    "subs x16, x16, #0x1\n"
+    "mov v19.16b, v13.16b\n"
+    "ldr q26, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v13.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v23.16b, v13.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "ssubl v0.8h, v0.8b, v9.8b\n"
-    "mov v25.16b, v19.16b\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v21.16b, v19.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "ssubl v1.8h, v1.8b, v9.8b\n"
-    "mov v20.16b, v19.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "ldr d4, [x16, #0x20]\n"
-    "ssubl v2.8h, v2.8b, v9.8b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "ssubl v3.8h, v3.8b, v9.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ldr d7, [x16, #0x38]\n"
-    "ssubl v4.8h, v4.8b, v9.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "ssubl v5.8h, v5.8b, v9.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "ssubl v6.8h, v6.8b, v9.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "ssubl v7.8h, v7.8b, v9.8b\n"
-    "ssubl v8.8h, v8.8b, v9.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ldr d31, [x23, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
-    "ldr d30, [x22, x17]\n"
-    "ldr d29, [x21, x17]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
-    "ldr d28, [x20, x17]\n"
-    "ldr d27, [x19, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "ssubl v27.8h, v27.8b, v14.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v11.16b, v26.16b\n"
+    "mov v18.16b, v13.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v24.16b, v26.16b\n"
+    "mov v9.16b, v13.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v23.16b, v26.16b\n"
+    "ssubl v0.8h, v0.8b, v12.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v1.8h, v1.8b, v12.8b\n"
+    "ssubl v2.8h, v2.8b, v12.8b\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "ssubl v3.8h, v3.8b, v12.8b\n"
+    "ssubl v4.8h, v4.8b, v12.8b\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "ssubl v5.8h, v5.8b, v12.8b\n"
+    "ssubl v6.8h, v6.8b, v12.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "ssubl v7.8h, v7.8b, v12.8b\n"
+    "ssubl v8.8h, v8.8b, v12.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "ssubl v27.8h, v27.8b, v22.8b\n"
     "beq 2f\n"
     "1:"  // Loop
     "smlal v13.4s, v31.4h, v4.4h\n"
-    "ldr x21, [x14, #0x28]\n"
-    "add x16, x16, #0x48\n"
-    "smlal2 v19.4s, v31.8h, v4.8h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "subs x12, x12, #0x1\n"
-    "smlal v17.4s, v31.4h, v3.4h\n"
-    "ldr x26, [x14, #0x38]\n"
-    "smlal2 v25.4s, v31.8h, v3.8h\n"
-    "ldr x25, [x14, #0x40]\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "ldr x19, [x14, #0x48]\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "ldr x24, [x14, #0x50]\n"
-    "smlal v23.4s, v31.4h, v0.4h\n"
-    "ldr x23, [x14, #0x58]\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x21, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x12, #0x28]\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x12, #0x30]\n"
+    "ldr x25, [x12, #0x40]\n"
     "smlal v13.4s, v30.4h, v0.4h\n"
-    "ldr x22, [x14, #0x60]\n"
-    "smlal2 v19.4s, v30.8h, v0.8h\n"
-    "ldr d30, [x19, x17]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "ldr x21, [x14, #0x68]\n"
-    "smlal2 v25.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x12, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x12, #0x50]\n"
+    "ldr x23, [x12, #0x58]\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
     "smlal v13.4s, v28.4h, v5.4h\n"
-    "ldr x20, [x14, #0x70]\n"
-    "smlal2 v19.4s, v28.8h, v5.8h\n"
-    "ldr x19, [x14, #0x78]\n"
-    "smlal v17.4s, v28.4h, v4.4h\n"
-    "ldr q26, [x13, #0x0]\n"
-    "smlal2 v25.4s, v28.8h, v4.8h\n"
-    "ldr q10, [x11, #0x0]\n"
-    "smlal v16.4s, v28.4h, v2.4h\n"
-    "ldr q11, [x13, #0x10]\n"
-    "add x13, x13, #0x20\n"
-    "smlal2 v21.4s, v28.8h, v2.8h\n"
-    "ldr q18, [x11, #0x10]\n"
-    "add x11, x11, #0x20\n"
-    "smlal v23.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x26, x17]\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v16.4s, v31.4h, v6.4h\n"
-    "smlal2 v21.4s, v31.8h, v6.8h\n"
-    "ldr d31, [x25, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v5.8h\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x12, #0x68]\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x12, #0x78]\n"
+    "ldr q21, [x13, #0x0]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
     "smlal v13.4s, v27.4h, v7.4h\n"
-    "smlal2 v19.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v27.4h, v6.4h\n"
-    "smlal2 v25.4s, v27.8h, v6.8h\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "smlal v23.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
+    "smlal2 v26.4s, v27.8h, v7.8h\n"
+    "ldr q25, [x11, #0x0]\n"
+    "ldr q10, [x13, #0x10]\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "ldr q16, [x11, #0x10]\n"
+    "add x17, x17, #0x48\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
+    "subs x16, x16, #0x1\n"
+    "add x13, x13, #0x20\n"
     "smlal v13.4s, v28.4h, v1.4h\n"
-    "smlal2 v19.4s, v28.8h, v1.8h\n"
-    "smlal v23.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "ldr d29, [x24, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v0.4h\n"
-    "smlal2 v25.4s, v28.8h, v0.8h\n"
-    "ldr d28, [x23, x17]\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "add x11, x11, #0x20\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
     "smlal v13.4s, v31.4h, v2.4h\n"
-    "smlal2 v19.4s, v31.8h, v2.8h\n"
-    "smlal v17.4s, v31.4h, v1.4h\n"
-    "smlal2 v25.4s, v31.8h, v1.8h\n"
-    "ldr d31, [x22, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v2.8h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
     "smlal v13.4s, v30.4h, v8.4h\n"
-    "smlal2 v19.4s, v30.8h, v8.8h\n"
-    "smlal v17.4s, v30.4h, v7.4h\n"
-    "smlal2 v25.4s, v30.8h, v7.8h\n"
-    "smlal v16.4s, v30.4h, v5.4h\n"
-    "smlal2 v21.4s, v30.8h, v5.8h\n"
-    "smlal v23.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x17]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v8.8h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
     "smlal v13.4s, v29.4h, v3.4h\n"
-    "smlal2 v19.4s, v29.8h, v3.8h\n"
-    "smlal v16.4s, v29.4h, v0.4h\n"
-    "smlal2 v21.4s, v29.8h, v0.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v5.4h\n"
-    "smlal2 v25.4s, v28.8h, v5.8h\n"
-    "smlal v23.4s, v28.4h, v2.4h\n"
-    "smlal2 v20.4s, v28.8h, v2.8h\n"
-    "ldr d28, [x19, x17]\n"
-    "add x17, x17, #0x8\n"
-    "smlal v13.4s, v31.4h, v6.4h\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "smlal2 v19.4s, v31.8h, v6.8h\n"
-    "smlal v16.4s, v31.4h, v3.4h\n"
-    "smlal2 v21.4s, v31.8h, v3.8h\n"
-    "smlal v17.4s, v30.4h, v8.4h\n"
-    "smlal2 v25.4s, v30.8h, v8.8h\n"
-    "smlal v23.4s, v30.4h, v5.4h\n"
-    "smlal2 v20.4s, v30.8h, v5.8h\n"
-    "smlal v16.4s, v29.4h, v7.4h\n"
-    "smlal2 v21.4s, v29.8h, v7.8h\n"
-    "smlal v23.4s, v29.4h, v6.4h\n"
-    "smlal2 v20.4s, v29.8h, v6.8h\n"
-    "smlal v16.4s, v28.4h, v8.4h\n"
-    "smlal2 v21.4s, v28.8h, v8.8h\n"
-    "smlal v23.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
-    "sqrdmulh v13.4s, v13.4s, v26.4s\n"
-    "sqrdmulh v19.4s, v19.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v26.4s\n"
-    "sqrdmulh v25.4s, v25.4s, v11.4s\n"
-    "and v22.16b, v13.16b, v10.16b\n"
-    "sshr v22.4s, v22.4s, #0x1f\n"
-    "and v28.16b, v19.16b, v18.16b\n"
-    "and v3.16b, v17.16b, v10.16b\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "and v6.16b, v25.16b, v18.16b\n"
-    "sqrdmulh v16.4s, v16.4s, v26.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v11.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sqadd v13.4s, v13.4s, v22.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v26.4s\n"
-    "and v0.16b, v16.16b, v10.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v13.4s, v13.4s, v10.4s\n"
-    "sqadd v19.4s, v19.4s, v28.4s\n"
-    "sqadd v17.4s, v17.4s, v3.4s\n"
-    "sqadd v25.4s, v25.4s, v6.4s\n"
-    "and v29.16b, v21.16b, v18.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "add v13.4s, v13.4s, v15.4s\n"
-    "srshl v19.4s, v19.4s, v18.4s\n"
-    "srshl v17.4s, v17.4s, v10.4s\n"
-    "srshl v25.4s, v25.4s, v18.4s\n"
-    "smin v13.4s, v13.4s, v12.4s\n"
-    "add v19.4s, v19.4s, v15.4s\n"
-    "add v17.4s, v17.4s, v15.4s\n"
-    "smax v13.4s, v13.4s, v24.4s\n"
-    "smin v19.4s, v19.4s, v12.4s\n"
-    "smin v17.4s, v17.4s, v12.4s\n"
-    "add v25.4s, v25.4s, v15.4s\n"
-    "smax v19.4s, v19.4s, v24.4s\n"
-    "smax v17.4s, v17.4s, v24.4s\n"
-    "smin v25.4s, v25.4s, v12.4s\n"
-    "uzp1 v13.16b, v13.16b, v19.16b\n"
-    "sqadd v16.4s, v16.4s, v0.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "str d13, [x10, x15]\n"
-    "smax v25.4s, v25.4s, v24.4s\n"
-    "sqadd v21.4s, v21.4s, v29.4s\n"
-    "srshl v16.4s, v16.4s, v10.4s\n"
-    "and v3.16b, v23.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v25.16b\n"
-    "add v16.4s, v16.4s, v15.4s\n"
-    "srshl v21.4s, v21.4s, v18.4s\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x9, x15]\n"
-    "smin v16.4s, v16.4s, v12.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v11.4s\n"
-    "add v21.4s, v21.4s, v15.4s\n"
-    "sqadd v23.4s, v23.4s, v3.4s\n"
-    "smax v16.4s, v16.4s, v24.4s\n"
-    "smin v21.4s, v21.4s, v12.4s\n"
-    "and v25.16b, v20.16b, v18.16b\n"
-    "sshr v25.4s, v25.4s, #0x1f\n"
-    "smax v21.4s, v21.4s, v24.4s\n"
-    "srshl v23.4s, v23.4s, v10.4s\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "add v23.4s, v23.4s, v15.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x28, x15]\n"
-    "smin v23.4s, v23.4s, v12.4s\n"
-    "sqadd v20.4s, v20.4s, v25.4s\n"
-    "smax v23.4s, v23.4s, v24.4s\n"
-    "srshl v20.4s, v20.4s, v18.4s\n"
-    "add v20.4s, v20.4s, v15.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v24.4s\n"
-    "uzp1 v23.16b, v23.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
-    "str d23, [x27, x15]\n"
+    "smlal2 v26.4s, v29.8h, v3.8h\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
     "add x15, x15, #0x8\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "str d13, [x10, x14]\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "str d19, [x9, x14]\n"
+    "str d18, [x28, x14]\n"
+    "str d9, [x27, x14]\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q13, [x19, #0x0]\n"
-    "mov v17.16b, v13.16b\n"
-    "ldr q19, [x19, #0x10]\n"
+    "add x14, x14, #0x8\n"
+    "ldr q26, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v13.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v23.16b, v13.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "ssubl v0.8h, v0.8b, v9.8b\n"
-    "mov v25.16b, v19.16b\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v21.16b, v19.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "ssubl v1.8h, v1.8b, v9.8b\n"
-    "mov v20.16b, v19.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "ldr d4, [x16, #0x20]\n"
-    "ssubl v2.8h, v2.8b, v9.8b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "ssubl v3.8h, v3.8b, v9.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ldr d7, [x16, #0x38]\n"
-    "ssubl v4.8h, v4.8b, v9.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "ssubl v5.8h, v5.8b, v9.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "ssubl v6.8h, v6.8b, v9.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "ssubl v7.8h, v7.8b, v9.8b\n"
-    "ssubl v8.8h, v8.8b, v9.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ldr d31, [x23, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
-    "ldr d30, [x22, x17]\n"
-    "ldr d29, [x21, x17]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
-    "ldr d28, [x20, x17]\n"
-    "ldr d27, [x19, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "ssubl v27.8h, v27.8b, v14.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v19.16b, v13.16b\n"
+    "mov v11.16b, v26.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v18.16b, v13.16b\n"
+    "mov v24.16b, v26.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v9.16b, v13.16b\n"
+    "mov v23.16b, v26.16b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v0.8h, v0.8b, v12.8b\n"
+    "ssubl v1.8h, v1.8b, v12.8b\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "ssubl v2.8h, v2.8b, v12.8b\n"
+    "ssubl v3.8h, v3.8b, v12.8b\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "ssubl v4.8h, v4.8b, v12.8b\n"
+    "ssubl v5.8h, v5.8b, v12.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "ssubl v6.8h, v6.8b, v12.8b\n"
+    "ssubl v7.8h, v7.8b, v12.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "ssubl v8.8h, v8.8b, v12.8b\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "ssubl v27.8h, v27.8b, v22.8b\n"
     "bgt 1b\n"
     "2:"  // Tail
     "smlal v13.4s, v31.4h, v4.4h\n"
-    "ldr x21, [x14, #0x28]\n"
-    "tst x8, #0x7\n"
-    "smlal2 v19.4s, v31.8h, v4.8h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "smlal v17.4s, v31.4h, v3.4h\n"
-    "ldr x26, [x14, #0x38]\n"
-    "smlal2 v25.4s, v31.8h, v3.8h\n"
-    "ldr x25, [x14, #0x40]\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "ldr x19, [x14, #0x48]\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "ldr x24, [x14, #0x50]\n"
-    "smlal v23.4s, v31.4h, v0.4h\n"
-    "ldr x23, [x14, #0x58]\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x21, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x12, #0x28]\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x12, #0x30]\n"
+    "ldr x25, [x12, #0x40]\n"
     "smlal v13.4s, v30.4h, v0.4h\n"
-    "ldr x22, [x14, #0x60]\n"
-    "smlal2 v19.4s, v30.8h, v0.8h\n"
-    "ldr d30, [x19, x17]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "ldr x21, [x14, #0x68]\n"
-    "smlal2 v25.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x12, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x12, #0x50]\n"
+    "ldr x23, [x12, #0x58]\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
     "smlal v13.4s, v28.4h, v5.4h\n"
-    "ldr x20, [x14, #0x70]\n"
-    "smlal2 v19.4s, v28.8h, v5.8h\n"
-    "ldr x19, [x14, #0x78]\n"
-    "smlal v17.4s, v28.4h, v4.4h\n"
-    "ldr q26, [x13, #0x0]\n"
-    "smlal2 v25.4s, v28.8h, v4.8h\n"
-    "ldr q10, [x11, #0x0]\n"
-    "smlal v16.4s, v28.4h, v2.4h\n"
-    "ldr q11, [x13, #0x10]\n"
-    "add x13, x13, #0x20\n"
-    "smlal2 v21.4s, v28.8h, v2.8h\n"
-    "ldr q18, [x11, #0x10]\n"
-    "add x11, x11, #0x20\n"
-    "smlal v23.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x26, x17]\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v16.4s, v31.4h, v6.4h\n"
-    "smlal2 v21.4s, v31.8h, v6.8h\n"
-    "ldr d31, [x25, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v5.8h\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x12, #0x68]\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x12, #0x78]\n"
+    "ldr q21, [x13, #0x0]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
     "smlal v13.4s, v27.4h, v7.4h\n"
-    "smlal2 v19.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v27.4h, v6.4h\n"
-    "smlal2 v25.4s, v27.8h, v6.8h\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "smlal v23.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
+    "smlal2 v26.4s, v27.8h, v7.8h\n"
+    "ldr q25, [x11, #0x0]\n"
+    "ldr q10, [x13, #0x10]\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "ldr q16, [x11, #0x10]\n"
+    "tst x8, #0x7\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
+    "add x13, x13, #0x20\n"
+    "add x11, x11, #0x20\n"
     "smlal v13.4s, v28.4h, v1.4h\n"
-    "smlal2 v19.4s, v28.8h, v1.8h\n"
-    "smlal v23.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "ldr d29, [x24, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v0.4h\n"
-    "smlal2 v25.4s, v28.8h, v0.8h\n"
-    "ldr d28, [x23, x17]\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
     "smlal v13.4s, v31.4h, v2.4h\n"
-    "smlal2 v19.4s, v31.8h, v2.8h\n"
-    "smlal v17.4s, v31.4h, v1.4h\n"
-    "smlal2 v25.4s, v31.8h, v1.8h\n"
-    "ldr d31, [x22, x17]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v2.8h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
     "smlal v13.4s, v30.4h, v8.4h\n"
-    "smlal2 v19.4s, v30.8h, v8.8h\n"
-    "smlal v17.4s, v30.4h, v7.4h\n"
-    "smlal2 v25.4s, v30.8h, v7.8h\n"
-    "smlal v16.4s, v30.4h, v5.4h\n"
-    "smlal2 v21.4s, v30.8h, v5.8h\n"
-    "smlal v23.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x17]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v8.8h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
     "smlal v13.4s, v29.4h, v3.4h\n"
-    "smlal2 v19.4s, v29.8h, v3.8h\n"
-    "smlal v16.4s, v29.4h, v0.4h\n"
-    "smlal2 v21.4s, v29.8h, v0.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v5.4h\n"
-    "smlal2 v25.4s, v28.8h, v5.8h\n"
-    "smlal v23.4s, v28.4h, v2.4h\n"
-    "smlal2 v20.4s, v28.8h, v2.8h\n"
-    "ldr d28, [x19, x17]\n"
-    "add x17, x17, #0x8\n"
-    "smlal v13.4s, v31.4h, v6.4h\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "smlal2 v19.4s, v31.8h, v6.8h\n"
-    "smlal v16.4s, v31.4h, v3.4h\n"
-    "smlal2 v21.4s, v31.8h, v3.8h\n"
-    "smlal v17.4s, v30.4h, v8.4h\n"
-    "smlal2 v25.4s, v30.8h, v8.8h\n"
-    "smlal v23.4s, v30.4h, v5.4h\n"
-    "smlal2 v20.4s, v30.8h, v5.8h\n"
-    "smlal v16.4s, v29.4h, v7.4h\n"
-    "smlal2 v21.4s, v29.8h, v7.8h\n"
-    "smlal v23.4s, v29.4h, v6.4h\n"
-    "smlal2 v20.4s, v29.8h, v6.8h\n"
-    "smlal v16.4s, v28.4h, v8.4h\n"
-    "smlal2 v21.4s, v28.8h, v8.8h\n"
-    "smlal v23.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
-    "sqrdmulh v13.4s, v13.4s, v26.4s\n"
-    "sqrdmulh v19.4s, v19.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v26.4s\n"
-    "sqrdmulh v25.4s, v25.4s, v11.4s\n"
-    "and v22.16b, v13.16b, v10.16b\n"
-    "sshr v22.4s, v22.4s, #0x1f\n"
-    "and v28.16b, v19.16b, v18.16b\n"
-    "and v3.16b, v17.16b, v10.16b\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "and v6.16b, v25.16b, v18.16b\n"
-    "sqrdmulh v16.4s, v16.4s, v26.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v11.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sqadd v13.4s, v13.4s, v22.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v26.4s\n"
-    "and v0.16b, v16.16b, v10.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v13.4s, v13.4s, v10.4s\n"
-    "sqadd v19.4s, v19.4s, v28.4s\n"
-    "sqadd v17.4s, v17.4s, v3.4s\n"
-    "sqadd v25.4s, v25.4s, v6.4s\n"
-    "and v29.16b, v21.16b, v18.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "add v13.4s, v13.4s, v15.4s\n"
-    "srshl v19.4s, v19.4s, v18.4s\n"
-    "srshl v17.4s, v17.4s, v10.4s\n"
-    "srshl v25.4s, v25.4s, v18.4s\n"
-    "smin v13.4s, v13.4s, v12.4s\n"
-    "add v19.4s, v19.4s, v15.4s\n"
-    "add v17.4s, v17.4s, v15.4s\n"
-    "smax v13.4s, v13.4s, v24.4s\n"
-    "smin v19.4s, v19.4s, v12.4s\n"
-    "smin v17.4s, v17.4s, v12.4s\n"
-    "add v25.4s, v25.4s, v15.4s\n"
-    "smax v19.4s, v19.4s, v24.4s\n"
-    "smax v17.4s, v17.4s, v24.4s\n"
-    "smin v25.4s, v25.4s, v12.4s\n"
-    "uzp1 v13.16b, v13.16b, v19.16b\n"
-    "sqadd v16.4s, v16.4s, v0.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "str d13, [x10, x15]\n"
-    "smax v25.4s, v25.4s, v24.4s\n"
-    "sqadd v21.4s, v21.4s, v29.4s\n"
-    "srshl v16.4s, v16.4s, v10.4s\n"
-    "and v3.16b, v23.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v25.16b\n"
-    "add v16.4s, v16.4s, v15.4s\n"
-    "srshl v21.4s, v21.4s, v18.4s\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x9, x15]\n"
-    "smin v16.4s, v16.4s, v12.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v11.4s\n"
-    "add v21.4s, v21.4s, v15.4s\n"
-    "sqadd v23.4s, v23.4s, v3.4s\n"
-    "smax v16.4s, v16.4s, v24.4s\n"
-    "smin v21.4s, v21.4s, v12.4s\n"
-    "and v25.16b, v20.16b, v18.16b\n"
-    "sshr v25.4s, v25.4s, #0x1f\n"
-    "smax v21.4s, v21.4s, v24.4s\n"
-    "srshl v23.4s, v23.4s, v10.4s\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "add v23.4s, v23.4s, v15.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x28, x15]\n"
-    "smin v23.4s, v23.4s, v12.4s\n"
-    "sqadd v20.4s, v20.4s, v25.4s\n"
-    "smax v23.4s, v23.4s, v24.4s\n"
-    "srshl v20.4s, v20.4s, v18.4s\n"
-    "add v20.4s, v20.4s, v15.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v24.4s\n"
-    "uzp1 v23.16b, v23.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
-    "str d23, [x27, x15]\n"
+    "smlal2 v26.4s, v29.8h, v3.8h\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
     "add x15, x15, #0x8\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "str d13, [x10, x14]\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "str d19, [x9, x14]\n"
+    "str d18, [x28, x14]\n"
+    "str d9, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
     "beq 64f\n"
-    "add x16, x16, #0x48\n"
+    "add x17, x17, #0x48\n"
     "3:"  // Oddments
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "tbz x8, #2, 5f\n"
     "ld1 { v13.4s }, [x19], #0x10\n"
     "tbz x8, #1, 4f\n"
-    "ld1 { v19.d }[0], [x19], #0x8\n"
+    "ld1 { v26.d }[0], [x19], #0x8\n"
     "tbz x8, #0, 7f\n"
-    "ld1 { v19.s }[2], [x19]\n"
+    "ld1 { v26.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
     "tbz x8, #0, 7f\n"
-    "ld1 { v19.s }[0], [x19]\n"
+    "ld1 { v26.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
     "tbz x8, #1, 6f\n"
@@ -609,38 +593,38 @@
     "tbz x8, #0, 7f\n"
     "ld1 { v13.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v17.16b, v13.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "mov v25.16b, v19.16b\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v16.16b, v13.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "mov v21.16b, v19.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "mov v23.16b, v13.16b\n"
-    "ldr d4, [x16, #0x20]\n"
-    "ssubl v0.8h, v0.8b, v9.8b\n"
-    "mov v20.16b, v19.16b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "ssubl v1.8h, v1.8b, v9.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ssubl v2.8h, v2.8b, v9.8b\n"
-    "ldr d7, [x16, #0x38]\n"
-    "ssubl v3.8h, v3.8b, v9.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "ssubl v4.8h, v4.8b, v9.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "ssubl v5.8h, v5.8b, v9.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "ssubl v6.8h, v6.8b, v9.8b\n"
-    "ssubl v7.8h, v7.8b, v9.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ssubl v8.8h, v8.8b, v9.8b\n"
-    "add x23, x23, x17\n"
-    "add x22, x22, x17\n"
-    "add x21, x21, x17\n"
-    "add x20, x20, x17\n"
-    "add x19, x19, x17\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v19.16b, v13.16b\n"
+    "mov v11.16b, v26.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v18.16b, v13.16b\n"
+    "mov v24.16b, v26.16b\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v9.16b, v13.16b\n"
+    "mov v23.16b, v26.16b\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ssubl v0.8h, v0.8b, v12.8b\n"
+    "ssubl v1.8h, v1.8b, v12.8b\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ssubl v2.8h, v2.8b, v12.8b\n"
+    "ssubl v3.8h, v3.8b, v12.8b\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ssubl v4.8h, v4.8b, v12.8b\n"
+    "ssubl v5.8h, v5.8b, v12.8b\n"
+    "ssubl v6.8h, v6.8b, v12.8b\n"
+    "ssubl v7.8h, v7.8b, v12.8b\n"
+    "ssubl v8.8h, v8.8b, v12.8b\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 9f\n"
     "ld1 { v31.s }[0], [x23], #0x4\n"
     "ld1 { v30.s }[0], [x22], #0x4\n"
@@ -690,33 +674,33 @@
     "ld1 { v28.b }[0], [x20]\n"
     "ld1 { v27.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "ldr x21, [x14, #0x28]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
     "smlal v13.4s, v31.4h, v4.4h\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
-    "smlal2 v19.4s, v31.8h, v4.8h\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v31.4h, v3.4h\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "smlal2 v25.4s, v31.8h, v3.8h\n"
-    "ssubl v27.8h, v27.8b, v14.8b\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "add x21, x21, x17\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "smlal v23.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
+    "smlal2 v26.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x12, #0x28]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "add x21, x21, x15\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
     "smlal v13.4s, v30.4h, v0.4h\n"
-    "smlal2 v19.4s, v30.8h, v0.8h\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "smlal2 v25.4s, v29.8h, v2.8h\n"
+    "smlal2 v26.4s, v30.8h, v0.8h\n"
+    "ssubl v27.8h, v27.8b, v22.8b\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
     "smlal v13.4s, v28.4h, v5.4h\n"
-    "smlal2 v19.4s, v28.8h, v5.8h\n"
-    "smlal v17.4s, v28.4h, v4.4h\n"
-    "smlal2 v25.4s, v28.8h, v4.8h\n"
-    "smlal v16.4s, v28.4h, v2.4h\n"
-    "smlal2 v21.4s, v28.8h, v2.8h\n"
-    "smlal v23.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
+    "smlal2 v26.4s, v28.8h, v5.8h\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
     "tbz x8, #2, 13f\n"
     "ld1 { v31.s }[0], [x21], #0x4\n"
     "tbz x8, #1, 12f\n"
@@ -738,19 +722,19 @@
     "tbz x8, #0, 15f\n"
     "ld1 { v31.b }[0], [x21]\n"
     "15:"  // Oddments: Load (3, 0): Bit 2: End
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr x20, [x12, #0x30]\n"
     "smlal v13.4s, v27.4h, v7.4h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
-    "smlal2 v19.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v27.4h, v6.4h\n"
-    "add x20, x20, x17\n"
-    "smlal2 v25.4s, v27.8h, v6.8h\n"
-    "smlal v23.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "smlal v16.4s, v31.4h, v6.4h\n"
-    "smlal2 v21.4s, v31.8h, v6.8h\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
+    "smlal2 v26.4s, v27.8h, v7.8h\n"
+    "add x20, x20, x15\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
     "tbz x8, #2, 17f\n"
     "ld1 { v29.s }[0], [x20], #0x4\n"
     "tbz x8, #1, 16f\n"
@@ -772,11 +756,11 @@
     "tbz x8, #0, 19f\n"
     "ld1 { v29.b }[0], [x20]\n"
     "19:"  // Oddments: Load (3, 3): Bit 2: End
-    "ldr x26, [x14, #0x38]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v23.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "add x26, x26, x17\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "add x26, x26, x15\n"
     "tbz x8, #2, 21f\n"
     "ld1 { v28.s }[0], [x26], #0x4\n"
     "tbz x8, #1, 20f\n"
@@ -798,13 +782,13 @@
     "tbz x8, #0, 23f\n"
     "ld1 { v28.b }[0], [x26]\n"
     "23:"  // Oddments: Load (0, 1): Bit 2: End
-    "ldr x25, [x14, #0x40]\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "ldr x25, [x12, #0x40]\n"
     "smlal v13.4s, v28.4h, v1.4h\n"
-    "smlal2 v19.4s, v28.8h, v1.8h\n"
-    "add x25, x25, x17\n"
-    "smlal v17.4s, v28.4h, v0.4h\n"
-    "smlal2 v25.4s, v28.8h, v0.8h\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "add x25, x25, x15\n"
     "tbz x8, #2, 25f\n"
     "ld1 { v31.s }[0], [x25], #0x4\n"
     "tbz x8, #1, 24f\n"
@@ -826,13 +810,13 @@
     "tbz x8, #0, 27f\n"
     "ld1 { v31.b }[0], [x25]\n"
     "27:"  // Oddments: Load (0, 2): Bit 2: End
-    "ldr x19, [x14, #0x48]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "ldr x19, [x12, #0x48]\n"
     "smlal v13.4s, v31.4h, v2.4h\n"
-    "smlal2 v19.4s, v31.8h, v2.8h\n"
-    "add x19, x19, x17\n"
-    "smlal v17.4s, v31.4h, v1.4h\n"
-    "smlal2 v25.4s, v31.8h, v1.8h\n"
+    "smlal2 v26.4s, v31.8h, v2.8h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 29f\n"
     "ld1 { v30.s }[0], [x19], #0x4\n"
     "tbz x8, #1, 28f\n"
@@ -854,17 +838,17 @@
     "tbz x8, #0, 31f\n"
     "ld1 { v30.b }[0], [x19]\n"
     "31:"  // Oddments: Load (2, 2): Bit 2: End
-    "ldr x24, [x14, #0x50]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x24, [x12, #0x50]\n"
     "smlal v13.4s, v30.4h, v8.4h\n"
-    "smlal2 v19.4s, v30.8h, v8.8h\n"
-    "add x24, x24, x17\n"
-    "smlal v17.4s, v30.4h, v7.4h\n"
-    "smlal2 v25.4s, v30.8h, v7.8h\n"
-    "smlal v16.4s, v30.4h, v5.4h\n"
-    "smlal2 v21.4s, v30.8h, v5.8h\n"
-    "smlal v23.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
+    "smlal2 v26.4s, v30.8h, v8.8h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "add x24, x24, x15\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
     "tbz x8, #2, 33f\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
     "tbz x8, #1, 32f\n"
@@ -886,13 +870,13 @@
     "tbz x8, #0, 35f\n"
     "ld1 { v29.b }[0], [x24]\n"
     "35:"  // Oddments: Load (1, 0): Bit 2: End
-    "ldr x23, [x14, #0x58]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x23, [x12, #0x58]\n"
     "smlal v13.4s, v29.4h, v3.4h\n"
-    "smlal2 v19.4s, v29.8h, v3.8h\n"
-    "add x23, x23, x17\n"
-    "smlal v16.4s, v29.4h, v0.4h\n"
-    "smlal2 v21.4s, v29.8h, v0.8h\n"
+    "smlal2 v26.4s, v29.8h, v3.8h\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "add x23, x23, x15\n"
     "tbz x8, #2, 37f\n"
     "ld1 { v28.s }[0], [x23], #0x4\n"
     "tbz x8, #1, 36f\n"
@@ -914,13 +898,13 @@
     "tbz x8, #0, 39f\n"
     "ld1 { v28.b }[0], [x23]\n"
     "39:"  // Oddments: Load (1, 3): Bit 2: End
-    "ldr x22, [x14, #0x60]\n"
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v5.4h\n"
-    "smlal2 v25.4s, v28.8h, v5.8h\n"
-    "add x22, x22, x17\n"
-    "smlal v23.4s, v28.4h, v2.4h\n"
-    "smlal2 v20.4s, v28.8h, v2.8h\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "add x22, x22, x15\n"
     "tbz x8, #2, 41f\n"
     "ld1 { v31.s }[0], [x22], #0x4\n"
     "tbz x8, #1, 40f\n"
@@ -942,13 +926,13 @@
     "tbz x8, #0, 43f\n"
     "ld1 { v31.b }[0], [x22]\n"
     "43:"  // Oddments: Load (2, 0): Bit 2: End
-    "ldr x21, [x14, #0x68]\n"
-    "ssubl v31.8h, v31.8b, v14.8b\n"
+    "ssubl v31.8h, v31.8b, v22.8b\n"
+    "ldr x21, [x12, #0x68]\n"
     "smlal v13.4s, v31.4h, v6.4h\n"
-    "smlal2 v19.4s, v31.8h, v6.8h\n"
-    "add x21, x21, x17\n"
-    "smlal v16.4s, v31.4h, v3.4h\n"
-    "smlal2 v21.4s, v31.8h, v3.8h\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "add x21, x21, x15\n"
     "tbz x8, #2, 45f\n"
     "ld1 { v30.s }[0], [x21], #0x4\n"
     "tbz x8, #1, 44f\n"
@@ -970,13 +954,13 @@
     "tbz x8, #0, 47f\n"
     "ld1 { v30.b }[0], [x21]\n"
     "47:"  // Oddments: Load (2, 3): Bit 2: End
-    "ldr x20, [x14, #0x70]\n"
-    "ssubl v30.8h, v30.8b, v14.8b\n"
-    "smlal v17.4s, v30.4h, v8.4h\n"
-    "smlal2 v25.4s, v30.8h, v8.8h\n"
-    "add x20, x20, x17\n"
-    "smlal v23.4s, v30.4h, v5.4h\n"
-    "smlal2 v20.4s, v30.8h, v5.8h\n"
+    "ssubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
+    "add x20, x20, x15\n"
     "tbz x8, #2, 49f\n"
     "ld1 { v29.s }[0], [x20], #0x4\n"
     "tbz x8, #1, 48f\n"
@@ -998,13 +982,13 @@
     "tbz x8, #0, 51f\n"
     "ld1 { v29.b }[0], [x20]\n"
     "51:"  // Oddments: Load (3, 1): Bit 2: End
-    "ldr x19, [x14, #0x78]\n"
-    "ssubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v16.4s, v29.4h, v7.4h\n"
-    "smlal2 v21.4s, v29.8h, v7.8h\n"
-    "add x19, x19, x17\n"
-    "smlal v23.4s, v29.4h, v6.4h\n"
-    "smlal2 v20.4s, v29.8h, v6.8h\n"
+    "ssubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x19, [x12, #0x78]\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 53f\n"
     "ld1 { v28.s }[0], [x19], #0x4\n"
     "tbz x8, #1, 52f\n"
@@ -1026,160 +1010,150 @@
     "tbz x8, #0, 55f\n"
     "ld1 { v28.b }[0], [x19]\n"
     "55:"  // Oddments: Load (3, 2): Bit 2: End
-    "ssubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v16.4s, v28.4h, v8.4h\n"
-    "smlal2 v21.4s, v28.8h, v8.8h\n"
-    "smlal v23.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
+    "ssubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
     "tbz x8, #2, 57f\n"
-    "ld1 { v26.4s }, [x13], #0x10\n"
-    "ld1 { v10.4s }, [x11], #0x10\n"
+    "ld1 { v21.4s }, [x13], #0x10\n"
+    "ld1 { v25.4s }, [x11], #0x10\n"
     "tbz x8, #1, 56f\n"
-    "ld1 { v11.d }[0], [x13], #0x8\n"
-    "ld1 { v18.d }[0], [x11], #0x8\n"
+    "ld1 { v10.d }[0], [x13], #0x8\n"
+    "ld1 { v16.d }[0], [x11], #0x8\n"
     "tbz x8, #0, 59f\n"
-    "ld1 { v11.s }[2], [x13]\n"
-    "ld1 { v18.s }[2], [x11]\n"
+    "ld1 { v10.s }[2], [x13]\n"
+    "ld1 { v16.s }[2], [x11]\n"
     "b 59f\n"
     "56:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
     "tbz x8, #0, 59f\n"
-    "ld1 { v11.s }[0], [x13]\n"
-    "ld1 { v18.s }[0], [x11]\n"
+    "ld1 { v10.s }[0], [x13]\n"
+    "ld1 { v16.s }[0], [x11]\n"
     "b 59f\n"
     "57:"  // Oddments: Load requant params: Bit 2: Unset
     "tbz x8, #1, 58f\n"
-    "ld1 { v26.d }[0], [x13], #0x8\n"
-    "ld1 { v10.d }[0], [x11], #0x8\n"
+    "ld1 { v21.d }[0], [x13], #0x8\n"
+    "ld1 { v25.d }[0], [x11], #0x8\n"
     "tbz x8, #0, 59f\n"
-    "ld1 { v26.s }[2], [x13]\n"
-    "ld1 { v10.s }[2], [x11]\n"
+    "ld1 { v21.s }[2], [x13]\n"
+    "ld1 { v25.s }[2], [x11]\n"
     "b 59f\n"
     "58:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
     "tbz x8, #0, 59f\n"
-    "ld1 { v26.s }[0], [x13]\n"
-    "ld1 { v10.s }[0], [x11]\n"
+    "ld1 { v21.s }[0], [x13]\n"
+    "ld1 { v25.s }[0], [x11]\n"
     "59:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v13.4s, v13.4s, v26.4s\n"
-    "add x10, x10, x15\n"
-    "sqrdmulh v19.4s, v19.4s, v11.4s\n"
-    "add x9, x9, x15\n"
-    "sqrdmulh v17.4s, v17.4s, v26.4s\n"
-    "add x28, x28, x15\n"
-    "sqrdmulh v25.4s, v25.4s, v11.4s\n"
-    "add x27, x27, x15\n"
-    "sqrdmulh v16.4s, v16.4s, v26.4s\n"
-    "and v22.16b, v13.16b, v10.16b\n"
-    "sshr v22.4s, v22.4s, #0x1f\n"
-    "and v28.16b, v19.16b, v18.16b\n"
-    "and v3.16b, v17.16b, v10.16b\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "and v6.16b, v25.16b, v18.16b\n"
-    "and v0.16b, v16.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v11.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sqadd v13.4s, v13.4s, v22.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v26.4s\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "sqrdmulh v20.4s, v20.4s, v11.4s\n"
-    "sqadd v19.4s, v19.4s, v28.4s\n"
-    "sqadd v17.4s, v17.4s, v3.4s\n"
-    "srshl v13.4s, v13.4s, v10.4s\n"
-    "sqadd v25.4s, v25.4s, v6.4s\n"
-    "srshl v19.4s, v19.4s, v18.4s\n"
-    "srshl v17.4s, v17.4s, v10.4s\n"
-    "add v13.4s, v13.4s, v15.4s\n"
-    "srshl v25.4s, v25.4s, v18.4s\n"
-    "add v19.4s, v19.4s, v15.4s\n"
-    "smin v13.4s, v13.4s, v12.4s\n"
-    "add v17.4s, v17.4s, v15.4s\n"
-    "smin v19.4s, v19.4s, v12.4s\n"
-    "smax v13.4s, v13.4s, v24.4s\n"
-    "smin v17.4s, v17.4s, v12.4s\n"
-    "smax v19.4s, v19.4s, v24.4s\n"
-    "add v25.4s, v25.4s, v15.4s\n"
-    "smax v17.4s, v17.4s, v24.4s\n"
-    "uzp1 v13.16b, v13.16b, v19.16b\n"
-    "smin v25.4s, v25.4s, v12.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "sqadd v16.4s, v16.4s, v0.4s\n"
-    "smax v25.4s, v25.4s, v24.4s\n"
-    "and v29.16b, v21.16b, v18.16b\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
     "sshr v29.4s, v29.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v25.16b\n"
-    "srshl v16.4s, v16.4s, v10.4s\n"
-    "and v3.16b, v23.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "add v16.4s, v16.4s, v15.4s\n"
-    "sqadd v21.4s, v21.4s, v29.4s\n"
-    "and v25.16b, v20.16b, v18.16b\n"
-    "sshr v25.4s, v25.4s, #0x1f\n"
-    "smin v16.4s, v16.4s, v12.4s\n"
-    "srshl v21.4s, v21.4s, v18.4s\n"
-    "sqadd v23.4s, v23.4s, v3.4s\n"
-    "smax v16.4s, v16.4s, v24.4s\n"
-    "add v21.4s, v21.4s, v15.4s\n"
-    "srshl v23.4s, v23.4s, v10.4s\n"
-    "sqadd v20.4s, v20.4s, v25.4s\n"
-    "smin v21.4s, v21.4s, v12.4s\n"
-    "add v23.4s, v23.4s, v15.4s\n"
-    "srshl v20.4s, v20.4s, v18.4s\n"
-    "smax v21.4s, v21.4s, v24.4s\n"
-    "smin v23.4s, v23.4s, v12.4s\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "add v20.4s, v20.4s, v15.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "smax v23.4s, v23.4s, v24.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v24.4s\n"
-    "uzp1 v23.16b, v23.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
     "tbz x8, #2, 61f\n"
     "st1 { v13.s }[0], [x10], #0x4\n"
-    "st1 { v17.s }[0], [x9], #0x4\n"
-    "st1 { v16.s }[0], [x28], #0x4\n"
-    "st1 { v23.s }[0], [x27], #0x4\n"
+    "st1 { v19.s }[0], [x9], #0x4\n"
+    "st1 { v18.s }[0], [x28], #0x4\n"
+    "st1 { v9.s }[0], [x27], #0x4\n"
     "tbz x8, #1, 60f\n"
     "st1 { v13.h }[2], [x10], #0x2\n"
-    "st1 { v17.h }[2], [x9], #0x2\n"
-    "st1 { v16.h }[2], [x28], #0x2\n"
-    "st1 { v23.h }[2], [x27], #0x2\n"
+    "st1 { v19.h }[2], [x9], #0x2\n"
+    "st1 { v18.h }[2], [x28], #0x2\n"
+    "st1 { v9.h }[2], [x27], #0x2\n"
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[6], [x10], #0x1\n"
-    "st1 { v17.b }[6], [x9], #0x1\n"
-    "st1 { v16.b }[6], [x28], #0x1\n"
-    "st1 { v23.b }[6], [x27], #0x1\n"
+    "st1 { v19.b }[6], [x9], #0x1\n"
+    "st1 { v18.b }[6], [x28], #0x1\n"
+    "st1 { v9.b }[6], [x27], #0x1\n"
     "b 63f\n"
     "60:"  // Oddments: Bit 2: Bit 1: Unset
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[4], [x10], #0x1\n"
-    "st1 { v17.b }[4], [x9], #0x1\n"
-    "st1 { v16.b }[4], [x28], #0x1\n"
-    "st1 { v23.b }[4], [x27], #0x1\n"
+    "st1 { v19.b }[4], [x9], #0x1\n"
+    "st1 { v18.b }[4], [x28], #0x1\n"
+    "st1 { v9.b }[4], [x27], #0x1\n"
     "b 63f\n"
     "61:"  // Oddments: Bit 2: Unset
     "tbz x8, #1, 62f\n"
     "st1 { v13.h }[0], [x10], #0x2\n"
-    "st1 { v17.h }[0], [x9], #0x2\n"
-    "st1 { v16.h }[0], [x28], #0x2\n"
-    "st1 { v23.h }[0], [x27], #0x2\n"
+    "st1 { v19.h }[0], [x9], #0x2\n"
+    "st1 { v18.h }[0], [x28], #0x2\n"
+    "st1 { v9.h }[0], [x27], #0x2\n"
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[2], [x10], #0x1\n"
-    "st1 { v17.b }[2], [x9], #0x1\n"
-    "st1 { v16.b }[2], [x28], #0x1\n"
-    "st1 { v23.b }[2], [x27], #0x1\n"
+    "st1 { v19.b }[2], [x9], #0x1\n"
+    "st1 { v18.b }[2], [x28], #0x1\n"
+    "st1 { v9.b }[2], [x27], #0x1\n"
     "b 63f\n"
     "62:"  // Oddments: Bit 2: Unset: Bit 1: Unset
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[0], [x10], #0x1\n"
-    "st1 { v17.b }[0], [x9], #0x1\n"
-    "st1 { v16.b }[0], [x28], #0x1\n"
-    "st1 { v23.b }[0], [x27], #0x1\n"
+    "st1 { v19.b }[0], [x9], #0x1\n"
+    "st1 { v18.b }[0], [x28], #0x1\n"
+    "st1 { v9.b }[0], [x27], #0x1\n"
     "63:"  // Oddments: Bit 2: End
-
     "64:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index b20759e..6032f8f 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
 
-struct a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst
+class a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>  // Strategy descriptor for the 3x3, stride-2 depthwise kernel; supplies geometry and the kernel entry point to the depthfirst base class.
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 2, 2) {}  // CPUInfo argument is ignored. NOTE(review): Parent args are presumably output rows/cols, kernel rows/cols, stride rows/cols - confirm against DepthwiseDepthfirstStrategy.
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }  // This kernel always reports VLType::None.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_3x3_mla::get_packed_size;
-
-  kern_type kernel = a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
-
-  a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;  // Defaults to the kernel implementation declared before this class.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel pointer.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth in units of vector length - confirm against the base class contract.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
index 3b3d9c8..5499392 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const int8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const int8_t *const *inptrs_raw,
-      const int8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -100,75 +100,75 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
-    "ldr x4, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x5, #0x0\n"
-    "ldr x6, [%x[params], %[offsetof_Params_weights]]\n"
-    "mov x7, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x8, %x[params], %[offsetof_Params_inptrs]\n"
-    "ldr x17, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x16, x4, #0x3\n"
-    "ldr x15, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x19, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v12.16b }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v13.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v11.4s }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v19.4s }, [x20]\n"
-    "ld1r { v14.4s }, [x19]\n"
-    "ldp x14, x13, [x21, #0x0]\n"
-    "ldp x12, x11, [x21, #0x10]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x24, x19, %[offsetof_Requantize32_a_offset]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x22, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x21, x19, %[offsetof_Requantize32_c_offset]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v12.16b }, [x24]\n"
+    "ld1r { v13.16b }, [x23]\n"
+    "lsr x16, x8, #0x3\n"
+    "ld1r { v11.8h }, [x21]\n"
+    "ld1r { v17.8h }, [x20]\n"
+    "mov x15, #0x0\n"
+    "mov x14, #0x0\n"
+    "ld1r { v14.8h }, [x19]\n"
+    "ldr x13, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "add x12, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x10, x9, [x22, #0x0]\n"
+    "ldp x28, x27, [x22, #0x10]\n"
     "cbz x16, 3f\n"
-    "subs x16, x16, #0x1\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q15, [x19, #0x0]\n"
-    "mov v20.16b, v15.16b\n"
+    "subs x16, x16, #0x1\n"
+    "mov v9.16b, v15.16b\n"
     "ldr q10, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v15.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v17.16b, v15.16b\n"
-    "ldr d0, [x6, #0x0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "mov v23.16b, v10.16b\n"
-    "ldr d1, [x6, #0x8]\n"
-    "mov v22.16b, v10.16b\n"
-    "ldr d2, [x6, #0x10]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v16.16b, v10.16b\n"
+    "mov v22.16b, v15.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v21.16b, v10.16b\n"
+    "mov v23.16b, v15.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
     "mov v18.16b, v10.16b\n"
-    "ldr d3, [x6, #0x18]\n"
-    "ldr d4, [x6, #0x20]\n"
+    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v1.8h, v1.8b, v13.8b\n"
     "ssubl v2.8h, v2.8b, v13.8b\n"
-    "ldr d5, [x6, #0x28]\n"
+    "ldp x26, x25, [x12, #0x0]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
     "ssubl v3.8h, v3.8b, v13.8b\n"
-    "ldr d6, [x6, #0x30]\n"
-    "ldr d7, [x6, #0x38]\n"
     "ssubl v4.8h, v4.8b, v13.8b\n"
-    "ldr d8, [x6, #0x40]\n"
+    "ldp x22, x21, [x12, #0x20]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
     "ssubl v5.8h, v5.8b, v13.8b\n"
-    "ldp x26, x25, [x8, #0x0]\n"
     "ssubl v6.8h, v6.8b, v13.8b\n"
-    "ldp x24, x23, [x8, #0x10]\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
     "ssubl v7.8h, v7.8b, v13.8b\n"
     "ssubl v8.8h, v8.8b, v13.8b\n"
-    "ldp x22, x21, [x8, #0x20]\n"
-    "ldp x20, x19, [x8, #0x30]\n"
-    "ldr d31, [x26, x5]\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
     "ssubl v31.8h, v31.8b, v12.8b\n"
-    "ldr d30, [x25, x5]\n"
-    "ldr d29, [x24, x5]\n"
     "ssubl v30.8h, v30.8b, v12.8b\n"
-    "ldr d28, [x23, x5]\n"
-    "ldr d27, [x22, x5]\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "ldr d26, [x21, x5]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
-    "ldr d25, [x20, x5]\n"
-    "ldr d24, [x19, x5]\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
     "ssubl v25.8h, v25.8b, v12.8b\n"
@@ -176,259 +176,251 @@
     "beq 2f\n"
     "1:"  // Loop
     "smlal v15.4s, v31.4h, v8.4h\n"
-    "ldr x23, [x8, #0x40]\n"
-    "add x6, x6, #0x48\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "ldr x22, [x8, #0x48]\n"
-    "subs x16, x16, #0x1\n"
-    "smlal v20.4s, v31.4h, v6.4h\n"
-    "ldr x21, [x8, #0x50]\n"
-    "smlal2 v23.4s, v31.8h, v6.8h\n"
-    "ldr x20, [x8, #0x58]\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "ldr x19, [x8, #0x60]\n"
-    "smlal2 v22.4s, v31.8h, v2.8h\n"
-    "ldr x10, [x8, #0x68]\n"
-    "smlal v17.4s, v31.4h, v0.4h\n"
-    "ldr x9, [x8, #0x70]\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x8, #0x78]\n"
+    "ldr x24, [x12, #0x40]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x12, #0x50]\n"
+    "ldr x19, [x12, #0x58]\n"
     "smlal v15.4s, v30.4h, v0.4h\n"
-    "ldr x27, [x8, #0x80]\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "ldr x26, [x8, #0x88]\n"
-    "smlal v20.4s, v28.4h, v1.4h\n"
-    "ldr x25, [x8, #0x90]\n"
-    "smlal2 v23.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x5]\n"
+    "ldr x22, [x12, #0x78]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
     "smlal v15.4s, v29.4h, v1.4h\n"
-    "ldr x24, [x8, #0x98]\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "ldr d29, [x23, x5]\n"
+    "ldr d29, [x24, x15]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v2.4h\n"
-    "ldr x23, [x8, #0xa0]\n"
-    "smlal2 v23.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x5]\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
     "smlal v15.4s, v26.4h, v3.4h\n"
-    "ldr x22, [x8, #0xa8]\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x20, x5]\n"
+    "ldr d26, [x19, x15]\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "ldr x19, [x12, #0x68]\n"
     "smlal v15.4s, v25.4h, v4.4h\n"
-    "ldr x21, [x8, #0xb0]\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
-    "ldr d25, [x19, x5]\n"
+    "ldr d25, [x20, x15]\n"
     "ssubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
-    "ldr x20, [x8, #0xb8]\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "ldr x19, [x8, #0xc0]\n"
-    "smlal v20.4s, v24.4h, v0.4h\n"
-    "ldr q21, [x17, #0x0]\n"
-    "smlal2 v23.4s, v24.8h, v0.8h\n"
-    "ldr d24, [x9, x5]\n"
-    "ssubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v4.4h\n"
-    "ldr q30, [x15, #0x0]\n"
-    "smlal2 v23.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x10, x5]\n"
+    "ldr x19, [x12, #0x70]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v5.4h\n"
-    "ldr q31, [x17, #0x10]\n"
-    "smlal2 v23.4s, v28.8h, v5.8h\n"
-    "ldr d28, [x27, x5]\n"
-    "add x17, x17, #0x20\n"
-    "smlal v15.4s, v27.4h, v5.4h\n"
-    "ldr q9, [x15, #0x10]\n"
-    "add x15, x15, #0x20\n"
-    "smlal2 v10.4s, v27.8h, v5.8h\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v3.4h\n"
-    "smlal2 v23.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x28, x5]\n"
-    "ssubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v26.4h, v3.4h\n"
-    "smlal2 v22.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x26, x5]\n"
-    "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v15.4s, v25.4h, v6.4h\n"
-    "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v22.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x25, x5]\n"
-    "ssubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v29.4h, v4.4h\n"
-    "smlal2 v22.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x24, x5]\n"
-    "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v15.4s, v24.4h, v7.4h\n"
-    "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v22.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x22, x5]\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x12, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
+    "smlal2 v10.4s, v27.8h, v5.8h\n"
     "ssubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v17.4s, v27.4h, v4.4h\n"
-    "smlal2 v18.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x5]\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v7.4h\n"
-    "smlal2 v23.4s, v28.8h, v7.8h\n"
-    "smlal v17.4s, v28.4h, v1.4h\n"
-    "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "smlal v16.4s, v25.4h, v6.4h\n"
-    "smlal2 v22.4s, v25.8h, v6.8h\n"
-    "ldr d25, [x20, x5]\n"
-    "ssubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v5.4h\n"
-    "smlal2 v18.4s, v26.8h, v5.8h\n"
-    "ldr d26, [x21, x5]\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v8.4h\n"
-    "smlal2 v23.4s, v29.8h, v8.8h\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "smlal2 v18.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x19, x5]\n"
-    "add x5, x5, #0x8\n"
-    "smlal v16.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "smlal2 v18.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "ssubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "ldr q19, [x13, #0x0]\n"
+    "smlal2 v10.4s, v25.8h, v6.8h\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "ssubl v25.8h, v25.8b, v12.8b\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v18.4s, v28.8h, v1.8h\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal2 v22.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v24.4h, v3.4h\n"
-    "smlal v16.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q0, [x11, #0x0]\n"
+    "ldr q4, [x13, #0x10]\n"
+    "smlal2 v10.4s, v24.8h, v7.8h\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "ldr q31, [x11, #0x10]\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v18.4s, v26.8h, v5.8h\n"
+    "ssubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v18.4s, v29.8h, v2.8h\n"
+    "ssubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "add x17, x17, #0x48\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "subs x16, x16, #0x1\n"
+    "smlal2 v21.4s, v25.8h, v6.8h\n"
+    "ldr d25, [x20, x15]\n"
     "smlal2 v18.4s, v24.8h, v3.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v21.4s\n"
-    "smlal2 v22.4s, v24.8h, v5.8h\n"
-    "smlal v17.4s, v26.4h, v7.4h\n"
+    "ssubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "add x13, x13, #0x20\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "ssubl v29.8h, v29.8b, v12.8b\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
     "smlal2 v18.4s, v26.8h, v7.8h\n"
-    "smlal v16.4s, v25.4h, v8.4h\n"
-    "smlal2 v22.4s, v25.8h, v8.8h\n"
-    "smlal v17.4s, v25.4h, v6.4h\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "add x11, x11, #0x20\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
     "smlal2 v18.4s, v25.8h, v6.8h\n"
-    "and v26.16b, v15.16b, v30.16b\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "smlal v17.4s, v29.4h, v8.4h\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
     "smlal2 v18.4s, v29.8h, v8.8h\n"
-    "sqrdmulh v10.4s, v10.4s, v31.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v21.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v31.4s\n"
-    "sqrdmulh v16.4s, v16.4s, v21.4s\n"
-    "sqadd v15.4s, v15.4s, v26.4s\n"
-    "and v8.16b, v10.16b, v9.16b\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v30.4s\n"
-    "and v4.16b, v20.16b, v30.16b\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "and v19.16b, v10.16b, v31.16b\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "and v2.16b, v23.16b, v9.16b\n"
-    "and v1.16b, v16.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v11.4s\n"
-    "sqadd v10.4s, v10.4s, v8.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "sqrdmulh v22.4s, v22.4s, v31.4s\n"
-    "sqadd v20.4s, v20.4s, v4.4s\n"
-    "smin v15.4s, v15.4s, v14.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "sqadd v23.4s, v23.4s, v2.4s\n"
-    "smax v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v30.4s\n"
-    "add v10.4s, v10.4s, v11.4s\n"
-    "srshl v23.4s, v23.4s, v9.4s\n"
-    "sqadd v16.4s, v16.4s, v1.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v11.4s\n"
-    "add v23.4s, v23.4s, v11.4s\n"
-    "smax v10.4s, v10.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v14.4s\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
-    "uzp1 v15.16b, v15.16b, v10.16b\n"
-    "smax v20.4s, v20.4s, v19.4s\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x14, x7]\n"
-    "smax v23.4s, v23.4s, v19.4s\n"
-    "srshl v16.4s, v16.4s, v30.4s\n"
-    "and v24.16b, v22.16b, v9.16b\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v23.16b\n"
-    "add v16.4s, v16.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v21.4s\n"
-    "uzp1 v20.16b, v20.16b, v20.16b\n"
-    "str d20, [x13, x7]\n"
-    "smin v16.4s, v16.4s, v14.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "sqadd v22.4s, v22.4s, v24.4s\n"
-    "and v2.16b, v17.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "smax v16.4s, v16.4s, v19.4s\n"
-    "srshl v22.4s, v22.4s, v9.4s\n"
-    "and v31.16b, v18.16b, v9.16b\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "add v22.4s, v22.4s, v11.4s\n"
-    "sqadd v17.4s, v17.4s, v2.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "srshl v17.4s, v17.4s, v30.4s\n"
-    "sqadd v18.4s, v18.4s, v31.4s\n"
-    "smax v22.4s, v22.4s, v19.4s\n"
-    "uzp1 v16.16b, v16.16b, v22.16b\n"
-    "add v17.4s, v17.4s, v11.4s\n"
-    "srshl v18.4s, v18.4s, v9.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x12, x7]\n"
-    "smin v17.4s, v17.4s, v14.4s\n"
-    "add v18.4s, v18.4s, v11.4s\n"
-    "smax v17.4s, v17.4s, v19.4s\n"
-    "smin v18.4s, v18.4s, v14.4s\n"
-    "smax v18.4s, v18.4s, v19.4s\n"
-    "uzp1 v17.16b, v17.16b, v18.16b\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x11, x7]\n"
-    "add x7, x7, #0x8\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "str d9, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d22, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q15, [x19, #0x0]\n"
-    "mov v20.16b, v15.16b\n"
+    "add x14, x14, #0x8\n"
     "ldr q10, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v15.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v17.16b, v15.16b\n"
-    "ldr d0, [x6, #0x0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "mov v23.16b, v10.16b\n"
-    "ldr d1, [x6, #0x8]\n"
-    "mov v22.16b, v10.16b\n"
-    "ldr d2, [x6, #0x10]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v9.16b, v15.16b\n"
+    "mov v16.16b, v10.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v22.16b, v15.16b\n"
+    "mov v21.16b, v10.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v23.16b, v15.16b\n"
     "mov v18.16b, v10.16b\n"
-    "ldr d3, [x6, #0x18]\n"
-    "ldr d4, [x6, #0x20]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "ssubl v1.8h, v1.8b, v13.8b\n"
+    "ldp x26, x25, [x12, #0x0]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
     "ssubl v2.8h, v2.8b, v13.8b\n"
-    "ldr d5, [x6, #0x28]\n"
     "ssubl v3.8h, v3.8b, v13.8b\n"
-    "ldr d6, [x6, #0x30]\n"
-    "ldr d7, [x6, #0x38]\n"
+    "ldp x22, x21, [x12, #0x20]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
     "ssubl v4.8h, v4.8b, v13.8b\n"
-    "ldr d8, [x6, #0x40]\n"
     "ssubl v5.8h, v5.8b, v13.8b\n"
-    "ldp x26, x25, [x8, #0x0]\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
     "ssubl v6.8h, v6.8b, v13.8b\n"
-    "ldp x24, x23, [x8, #0x10]\n"
     "ssubl v7.8h, v7.8b, v13.8b\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
     "ssubl v8.8h, v8.8b, v13.8b\n"
-    "ldp x22, x21, [x8, #0x20]\n"
-    "ldp x20, x19, [x8, #0x30]\n"
-    "ldr d31, [x26, x5]\n"
     "ssubl v31.8h, v31.8b, v12.8b\n"
-    "ldr d30, [x25, x5]\n"
-    "ldr d29, [x24, x5]\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
     "ssubl v30.8h, v30.8b, v12.8b\n"
-    "ldr d28, [x23, x5]\n"
-    "ldr d27, [x22, x5]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "ldr d26, [x21, x5]\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
-    "ldr d25, [x20, x5]\n"
-    "ldr d24, [x19, x5]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
     "ssubl v25.8h, v25.8b, v12.8b\n"
@@ -436,275 +428,267 @@
     "bgt 1b\n"
     "2:"  // Tail
     "smlal v15.4s, v31.4h, v8.4h\n"
-    "ldr x23, [x8, #0x40]\n"
-    "tst x4, #0x7\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "ldr x22, [x8, #0x48]\n"
-    "smlal v20.4s, v31.4h, v6.4h\n"
-    "ldr x21, [x8, #0x50]\n"
-    "smlal2 v23.4s, v31.8h, v6.8h\n"
-    "ldr x20, [x8, #0x58]\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "ldr x19, [x8, #0x60]\n"
-    "smlal2 v22.4s, v31.8h, v2.8h\n"
-    "ldr x10, [x8, #0x68]\n"
-    "smlal v17.4s, v31.4h, v0.4h\n"
-    "ldr x9, [x8, #0x70]\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x8, #0x78]\n"
+    "ldr x24, [x12, #0x40]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x12, #0x50]\n"
+    "ldr x19, [x12, #0x58]\n"
     "smlal v15.4s, v30.4h, v0.4h\n"
-    "ldr x27, [x8, #0x80]\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "ldr x26, [x8, #0x88]\n"
-    "smlal v20.4s, v28.4h, v1.4h\n"
-    "ldr x25, [x8, #0x90]\n"
-    "smlal2 v23.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x5]\n"
+    "ldr x22, [x12, #0x78]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
     "smlal v15.4s, v29.4h, v1.4h\n"
-    "ldr x24, [x8, #0x98]\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "ldr d29, [x23, x5]\n"
+    "ldr d29, [x24, x15]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v2.4h\n"
-    "ldr x23, [x8, #0xa0]\n"
-    "smlal2 v23.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x5]\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
     "smlal v15.4s, v26.4h, v3.4h\n"
-    "ldr x22, [x8, #0xa8]\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x20, x5]\n"
+    "ldr d26, [x19, x15]\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "ldr x19, [x12, #0x68]\n"
     "smlal v15.4s, v25.4h, v4.4h\n"
-    "ldr x21, [x8, #0xb0]\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
-    "ldr d25, [x19, x5]\n"
+    "ldr d25, [x20, x15]\n"
     "ssubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
-    "ldr x20, [x8, #0xb8]\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "ldr x19, [x8, #0xc0]\n"
-    "smlal v20.4s, v24.4h, v0.4h\n"
-    "ldr q21, [x17, #0x0]\n"
-    "smlal2 v23.4s, v24.8h, v0.8h\n"
-    "ldr d24, [x9, x5]\n"
-    "ssubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v4.4h\n"
-    "ldr q30, [x15, #0x0]\n"
-    "smlal2 v23.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x10, x5]\n"
+    "ldr x19, [x12, #0x70]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v5.4h\n"
-    "ldr q31, [x17, #0x10]\n"
-    "smlal2 v23.4s, v28.8h, v5.8h\n"
-    "ldr d28, [x27, x5]\n"
-    "add x17, x17, #0x20\n"
-    "smlal v15.4s, v27.4h, v5.4h\n"
-    "ldr q9, [x15, #0x10]\n"
-    "add x15, x15, #0x20\n"
-    "smlal2 v10.4s, v27.8h, v5.8h\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v3.4h\n"
-    "smlal2 v23.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x28, x5]\n"
-    "ssubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v26.4h, v3.4h\n"
-    "smlal2 v22.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x26, x5]\n"
-    "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v15.4s, v25.4h, v6.4h\n"
-    "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v22.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x25, x5]\n"
-    "ssubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v29.4h, v4.4h\n"
-    "smlal2 v22.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x24, x5]\n"
-    "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v15.4s, v24.4h, v7.4h\n"
-    "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v22.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x22, x5]\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x12, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
+    "smlal2 v10.4s, v27.8h, v5.8h\n"
     "ssubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v17.4s, v27.4h, v4.4h\n"
-    "smlal2 v18.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x5]\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v7.4h\n"
-    "smlal2 v23.4s, v28.8h, v7.8h\n"
-    "smlal v17.4s, v28.4h, v1.4h\n"
-    "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "smlal v16.4s, v25.4h, v6.4h\n"
-    "smlal2 v22.4s, v25.8h, v6.8h\n"
-    "ldr d25, [x20, x5]\n"
-    "ssubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v5.4h\n"
-    "smlal2 v18.4s, v26.8h, v5.8h\n"
-    "ldr d26, [x21, x5]\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v8.4h\n"
-    "smlal2 v23.4s, v29.8h, v8.8h\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "smlal2 v18.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x19, x5]\n"
-    "add x5, x5, #0x8\n"
-    "smlal v16.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "smlal2 v18.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "ssubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "ldr q19, [x13, #0x0]\n"
+    "smlal2 v10.4s, v25.8h, v6.8h\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "ssubl v25.8h, v25.8b, v12.8b\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v18.4s, v28.8h, v1.8h\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal2 v22.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v24.4h, v3.4h\n"
-    "smlal v16.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q0, [x11, #0x0]\n"
+    "ldr q4, [x13, #0x10]\n"
+    "smlal2 v10.4s, v24.8h, v7.8h\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "ldr q31, [x11, #0x10]\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v18.4s, v26.8h, v5.8h\n"
+    "ssubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v18.4s, v29.8h, v2.8h\n"
+    "ssubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "tst x8, #0x7\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "add x13, x13, #0x20\n"
+    "smlal2 v21.4s, v25.8h, v6.8h\n"
+    "ldr d25, [x20, x15]\n"
     "smlal2 v18.4s, v24.8h, v3.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v21.4s\n"
-    "smlal2 v22.4s, v24.8h, v5.8h\n"
-    "smlal v17.4s, v26.4h, v7.4h\n"
+    "ssubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "add x11, x11, #0x20\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "ssubl v29.8h, v29.8b, v12.8b\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
     "smlal2 v18.4s, v26.8h, v7.8h\n"
-    "smlal v16.4s, v25.4h, v8.4h\n"
-    "smlal2 v22.4s, v25.8h, v8.8h\n"
-    "smlal v17.4s, v25.4h, v6.4h\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
     "smlal2 v18.4s, v25.8h, v6.8h\n"
-    "and v26.16b, v15.16b, v30.16b\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "smlal v17.4s, v29.4h, v8.4h\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
     "smlal2 v18.4s, v29.8h, v8.8h\n"
-    "sqrdmulh v10.4s, v10.4s, v31.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v21.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v31.4s\n"
-    "sqrdmulh v16.4s, v16.4s, v21.4s\n"
-    "sqadd v15.4s, v15.4s, v26.4s\n"
-    "and v8.16b, v10.16b, v9.16b\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v30.4s\n"
-    "and v4.16b, v20.16b, v30.16b\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "and v19.16b, v10.16b, v31.16b\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "and v2.16b, v23.16b, v9.16b\n"
-    "and v1.16b, v16.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v11.4s\n"
-    "sqadd v10.4s, v10.4s, v8.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "sqrdmulh v22.4s, v22.4s, v31.4s\n"
-    "sqadd v20.4s, v20.4s, v4.4s\n"
-    "smin v15.4s, v15.4s, v14.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "sqadd v23.4s, v23.4s, v2.4s\n"
-    "smax v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v30.4s\n"
-    "add v10.4s, v10.4s, v11.4s\n"
-    "srshl v23.4s, v23.4s, v9.4s\n"
-    "sqadd v16.4s, v16.4s, v1.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v11.4s\n"
-    "add v23.4s, v23.4s, v11.4s\n"
-    "smax v10.4s, v10.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v14.4s\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
-    "uzp1 v15.16b, v15.16b, v10.16b\n"
-    "smax v20.4s, v20.4s, v19.4s\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x14, x7]\n"
-    "smax v23.4s, v23.4s, v19.4s\n"
-    "srshl v16.4s, v16.4s, v30.4s\n"
-    "and v24.16b, v22.16b, v9.16b\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v23.16b\n"
-    "add v16.4s, v16.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v21.4s\n"
-    "uzp1 v20.16b, v20.16b, v20.16b\n"
-    "str d20, [x13, x7]\n"
-    "smin v16.4s, v16.4s, v14.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "sqadd v22.4s, v22.4s, v24.4s\n"
-    "and v2.16b, v17.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "smax v16.4s, v16.4s, v19.4s\n"
-    "srshl v22.4s, v22.4s, v9.4s\n"
-    "and v31.16b, v18.16b, v9.16b\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "add v22.4s, v22.4s, v11.4s\n"
-    "sqadd v17.4s, v17.4s, v2.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "srshl v17.4s, v17.4s, v30.4s\n"
-    "sqadd v18.4s, v18.4s, v31.4s\n"
-    "smax v22.4s, v22.4s, v19.4s\n"
-    "uzp1 v16.16b, v16.16b, v22.16b\n"
-    "add v17.4s, v17.4s, v11.4s\n"
-    "srshl v18.4s, v18.4s, v9.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x12, x7]\n"
-    "smin v17.4s, v17.4s, v14.4s\n"
-    "add v18.4s, v18.4s, v11.4s\n"
-    "smax v17.4s, v17.4s, v19.4s\n"
-    "smin v18.4s, v18.4s, v14.4s\n"
-    "smax v18.4s, v18.4s, v19.4s\n"
-    "uzp1 v17.16b, v17.16b, v18.16b\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x11, x7]\n"
-    "add x7, x7, #0x8\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "str d9, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d22, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
     "beq 88f\n"
-    "add x6, x6, #0x48\n"
+    "add x17, x17, #0x48\n"
     "3:"  // Oddments
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "tbz x4, #2, 5f\n"
+    "tbz x8, #2, 5f\n"
     "ld1 { v15.4s }, [x19], #0x10\n"
-    "tbz x4, #1, 4f\n"
+    "tbz x8, #1, 4f\n"
     "ld1 { v10.d }[0], [x19], #0x8\n"
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v10.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v10.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
-    "tbz x4, #1, 6f\n"
+    "tbz x8, #1, 6f\n"
     "ld1 { v15.d }[0], [x19], #0x8\n"
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v15.s }[2], [x19]\n"
     "b 7f\n"
     "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v15.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v20.16b, v15.16b\n"
-    "ldr d0, [x6, #0x0]\n"
-    "mov v23.16b, v10.16b\n"
-    "ldr d1, [x6, #0x8]\n"
-    "mov v16.16b, v15.16b\n"
-    "ldr d2, [x6, #0x10]\n"
-    "mov v22.16b, v10.16b\n"
-    "ldr d3, [x6, #0x18]\n"
-    "mov v17.16b, v15.16b\n"
-    "ldr d4, [x6, #0x20]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v9.16b, v15.16b\n"
+    "mov v16.16b, v10.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v22.16b, v15.16b\n"
+    "mov v21.16b, v10.16b\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v23.16b, v15.16b\n"
     "mov v18.16b, v10.16b\n"
-    "ldr d5, [x6, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ssubl v0.8h, v0.8b, v13.8b\n"
     "ssubl v1.8h, v1.8b, v13.8b\n"
-    "ldr d6, [x6, #0x30]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x26, x25, [x12, #0x0]\n"
     "ssubl v2.8h, v2.8b, v13.8b\n"
-    "ldr d7, [x6, #0x38]\n"
     "ssubl v3.8h, v3.8b, v13.8b\n"
-    "ldr d8, [x6, #0x40]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
+    "ldp x22, x21, [x12, #0x20]\n"
     "ssubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x26, x25, [x8, #0x0]\n"
     "ssubl v5.8h, v5.8b, v13.8b\n"
-    "ldp x24, x23, [x8, #0x10]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
     "ssubl v6.8h, v6.8b, v13.8b\n"
     "ssubl v7.8h, v7.8b, v13.8b\n"
-    "ldp x22, x21, [x8, #0x20]\n"
     "ssubl v8.8h, v8.8b, v13.8b\n"
-    "ldp x20, x19, [x8, #0x30]\n"
-    "add x26, x26, x5\n"
-    "add x25, x25, x5\n"
-    "add x24, x24, x5\n"
-    "add x23, x23, x5\n"
-    "add x22, x22, x5\n"
-    "add x21, x21, x5\n"
-    "add x20, x20, x5\n"
-    "add x19, x19, x5\n"
-    "tbz x4, #2, 9f\n"
+    "add x26, x26, x15\n"
+    "add x25, x25, x15\n"
+    "add x24, x24, x15\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 9f\n"
     "ld1 { v31.s }[0], [x26], #0x4\n"
     "ld1 { v30.s }[0], [x25], #0x4\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
@@ -713,7 +697,7 @@
     "ld1 { v26.s }[0], [x21], #0x4\n"
     "ld1 { v25.s }[0], [x20], #0x4\n"
     "ld1 { v24.s }[0], [x19], #0x4\n"
-    "tbz x4, #1, 8f\n"
+    "tbz x8, #1, 8f\n"
     "ld1 { v31.h }[2], [x26], #0x2\n"
     "ld1 { v30.h }[2], [x25], #0x2\n"
     "ld1 { v29.h }[2], [x24], #0x2\n"
@@ -722,7 +706,7 @@
     "ld1 { v26.h }[2], [x21], #0x2\n"
     "ld1 { v25.h }[2], [x20], #0x2\n"
     "ld1 { v24.h }[2], [x19], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[6], [x26]\n"
     "ld1 { v30.b }[6], [x25]\n"
     "ld1 { v29.b }[6], [x24]\n"
@@ -733,7 +717,7 @@
     "ld1 { v24.b }[6], [x19]\n"
     "b 11f\n"
     "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[4], [x26]\n"
     "ld1 { v30.b }[4], [x25]\n"
     "ld1 { v29.b }[4], [x24]\n"
@@ -744,7 +728,7 @@
     "ld1 { v24.b }[4], [x19]\n"
     "b 11f\n"
     "9:"  // Oddments: Initial loads: Bit 2: Unset
-    "tbz x4, #1, 10f\n"
+    "tbz x8, #1, 10f\n"
     "ld1 { v31.h }[0], [x26], #0x2\n"
     "ld1 { v30.h }[0], [x25], #0x2\n"
     "ld1 { v29.h }[0], [x24], #0x2\n"
@@ -753,7 +737,7 @@
     "ld1 { v26.h }[0], [x21], #0x2\n"
     "ld1 { v25.h }[0], [x20], #0x2\n"
     "ld1 { v24.h }[0], [x19], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[2], [x26]\n"
     "ld1 { v30.b }[2], [x25]\n"
     "ld1 { v29.b }[2], [x24]\n"
@@ -764,7 +748,7 @@
     "ld1 { v24.b }[2], [x19]\n"
     "b 11f\n"
     "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[0], [x26]\n"
     "ld1 { v30.b }[0], [x25]\n"
     "ld1 { v29.b }[0], [x24]\n"
@@ -774,646 +758,636 @@
     "ld1 { v25.b }[0], [x20]\n"
     "ld1 { v24.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "ldr x23, [x8, #0x40]\n"
     "ssubl v31.8h, v31.8b, v12.8b\n"
     "smlal v15.4s, v31.4h, v8.4h\n"
-    "ssubl v30.8h, v30.8b, v12.8b\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v31.4h, v6.4h\n"
-    "ssubl v28.8h, v28.8b, v12.8b\n"
-    "smlal2 v23.4s, v31.8h, v6.8h\n"
-    "ssubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal2 v22.4s, v31.8h, v2.8h\n"
-    "ssubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v17.4s, v31.4h, v0.4h\n"
-    "ssubl v24.8h, v24.8b, v12.8b\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "add x23, x23, x5\n"
+    "ldr x24, [x12, #0x40]\n"
+    "ssubl v30.8h, v30.8b, v12.8b\n"
     "smlal v15.4s, v30.4h, v0.4h\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "smlal v20.4s, v28.4h, v1.4h\n"
-    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "add x24, x24, x15\n"
+    "ssubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
     "smlal v15.4s, v29.4h, v1.4h\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "smlal v20.4s, v27.4h, v2.4h\n"
-    "smlal2 v23.4s, v27.8h, v2.8h\n"
+    "ssubl v28.8h, v28.8b, v12.8b\n"
+    "ssubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
     "smlal v15.4s, v26.4h, v3.4h\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "smlal v20.4s, v24.4h, v0.4h\n"
-    "smlal2 v23.4s, v24.8h, v0.8h\n"
+    "ssubl v27.8h, v27.8b, v12.8b\n"
+    "ssubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v4.4h\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
+    "ssubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "tbz x4, #2, 13f\n"
-    "ld1 { v29.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 12f\n"
-    "ld1 { v29.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[6], [x23]\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "tbz x8, #2, 13f\n"
+    "ld1 { v29.s }[0], [x24], #0x4\n"
+    "tbz x8, #1, 12f\n"
+    "ld1 { v29.h }[2], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[6], [x24]\n"
     "b 15f\n"
     "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[4], [x23]\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[4], [x24]\n"
     "b 15f\n"
     "13:"  // Oddments: Load (1, 3): Bit 2: Unset
-    "tbz x4, #1, 14f\n"
-    "ld1 { v29.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[2], [x23]\n"
+    "tbz x8, #1, 14f\n"
+    "ld1 { v29.h }[0], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[2], [x24]\n"
     "b 15f\n"
     "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[0], [x23]\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[0], [x24]\n"
     "15:"  // Oddments: Load (1, 3): Bit 2: End
-    "ldr x22, [x8, #0x48]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v4.4h\n"
-    "smlal2 v23.4s, v29.8h, v4.8h\n"
-    "add x22, x22, x5\n"
-    "tbz x4, #2, 17f\n"
-    "ld1 { v28.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 16f\n"
-    "ld1 { v28.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[6], [x22]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 17f\n"
+    "ld1 { v28.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 16f\n"
+    "ld1 { v28.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[6], [x23]\n"
     "b 19f\n"
     "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[4], [x22]\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[4], [x23]\n"
     "b 19f\n"
     "17:"  // Oddments: Load (1, 4): Bit 2: Unset
-    "tbz x4, #1, 18f\n"
-    "ld1 { v28.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[2], [x22]\n"
+    "tbz x8, #1, 18f\n"
+    "ld1 { v28.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[2], [x23]\n"
     "b 19f\n"
     "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[0], [x22]\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[0], [x23]\n"
     "19:"  // Oddments: Load (1, 4): Bit 2: End
-    "ldr x21, [x8, #0x50]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v5.4h\n"
-    "smlal2 v23.4s, v28.8h, v5.8h\n"
-    "add x21, x21, x5\n"
-    "tbz x4, #2, 21f\n"
+    "ldr x21, [x12, #0x50]\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 21f\n"
     "ld1 { v27.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 20f\n"
+    "tbz x8, #1, 20f\n"
     "ld1 { v27.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[6], [x21]\n"
     "b 23f\n"
     "20:"  // Oddments: Load (1, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[4], [x21]\n"
     "b 23f\n"
     "21:"  // Oddments: Load (1, 2): Bit 2: Unset
-    "tbz x4, #1, 22f\n"
+    "tbz x8, #1, 22f\n"
     "ld1 { v27.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[2], [x21]\n"
     "b 23f\n"
     "22:"  // Oddments: Load (1, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[0], [x21]\n"
     "23:"  // Oddments: Load (1, 2): Bit 2: End
-    "ldr x20, [x8, #0x58]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
+    "ldr x19, [x12, #0x58]\n"
     "smlal v15.4s, v27.4h, v5.4h\n"
     "smlal2 v10.4s, v27.8h, v5.8h\n"
-    "add x20, x20, x5\n"
-    "smlal v20.4s, v27.4h, v3.4h\n"
-    "smlal2 v23.4s, v27.8h, v3.8h\n"
-    "tbz x4, #2, 25f\n"
-    "ld1 { v26.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 24f\n"
-    "ld1 { v26.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[6], [x20]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 25f\n"
+    "ld1 { v26.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 24f\n"
+    "ld1 { v26.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[6], [x19]\n"
     "b 27f\n"
     "24:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[4], [x20]\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[4], [x19]\n"
     "b 27f\n"
     "25:"  // Oddments: Load (3, 0): Bit 2: Unset
-    "tbz x4, #1, 26f\n"
-    "ld1 { v26.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[2], [x20]\n"
+    "tbz x8, #1, 26f\n"
+    "ld1 { v26.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[2], [x19]\n"
     "b 27f\n"
     "26:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[0], [x20]\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[0], [x19]\n"
     "27:"  // Oddments: Load (3, 0): Bit 2: End
-    "ldr x19, [x8, #0x60]\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v16.4s, v26.4h, v3.4h\n"
-    "smlal2 v22.4s, v26.8h, v3.8h\n"
-    "add x19, x19, x5\n"
-    "tbz x4, #2, 29f\n"
-    "ld1 { v25.s }[0], [x19], #0x4\n"
-    "tbz x4, #1, 28f\n"
-    "ld1 { v25.h }[2], [x19], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[6], [x19]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 29f\n"
+    "ld1 { v25.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 28f\n"
+    "ld1 { v25.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[6], [x20]\n"
     "b 31f\n"
     "28:"  // Oddments: Load (2, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[4], [x19]\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[4], [x20]\n"
     "b 31f\n"
     "29:"  // Oddments: Load (2, 0): Bit 2: Unset
-    "tbz x4, #1, 30f\n"
-    "ld1 { v25.h }[0], [x19], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[2], [x19]\n"
+    "tbz x8, #1, 30f\n"
+    "ld1 { v25.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[2], [x20]\n"
     "b 31f\n"
     "30:"  // Oddments: Load (2, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[0], [x19]\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[0], [x20]\n"
     "31:"  // Oddments: Load (2, 0): Bit 2: End
-    "ldr x10, [x8, #0x68]\n"
     "ssubl v25.8h, v25.8b, v12.8b\n"
+    "ldr x19, [x12, #0x68]\n"
     "smlal v15.4s, v25.4h, v6.4h\n"
     "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "add x10, x10, x5\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v22.4s, v25.8h, v0.8h\n"
-    "tbz x4, #2, 33f\n"
-    "ld1 { v29.s }[0], [x10], #0x4\n"
-    "tbz x4, #1, 32f\n"
-    "ld1 { v29.h }[2], [x10], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[6], [x10]\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 33f\n"
+    "ld1 { v29.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 32f\n"
+    "ld1 { v29.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[6], [x19]\n"
     "b 35f\n"
     "32:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[4], [x10]\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[4], [x19]\n"
     "b 35f\n"
     "33:"  // Oddments: Load (3, 1): Bit 2: Unset
-    "tbz x4, #1, 34f\n"
-    "ld1 { v29.h }[0], [x10], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[2], [x10]\n"
+    "tbz x8, #1, 34f\n"
+    "ld1 { v29.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[2], [x19]\n"
     "b 35f\n"
     "34:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[0], [x10]\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[0], [x19]\n"
     "35:"  // Oddments: Load (3, 1): Bit 2: End
-    "ldr x9, [x8, #0x70]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v16.4s, v29.4h, v4.4h\n"
-    "smlal2 v22.4s, v29.8h, v4.8h\n"
-    "add x9, x9, x5\n"
-    "tbz x4, #2, 37f\n"
-    "ld1 { v24.s }[0], [x9], #0x4\n"
-    "tbz x4, #1, 36f\n"
-    "ld1 { v24.h }[2], [x9], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[6], [x9]\n"
+    "ldr x19, [x12, #0x70]\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 37f\n"
+    "ld1 { v24.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 36f\n"
+    "ld1 { v24.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[6], [x19]\n"
     "b 39f\n"
     "36:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[4], [x9]\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[4], [x19]\n"
     "b 39f\n"
     "37:"  // Oddments: Load (2, 1): Bit 2: Unset
-    "tbz x4, #1, 38f\n"
-    "ld1 { v24.h }[0], [x9], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[2], [x9]\n"
+    "tbz x8, #1, 38f\n"
+    "ld1 { v24.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[2], [x19]\n"
     "b 39f\n"
     "38:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[0], [x9]\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[0], [x19]\n"
     "39:"  // Oddments: Load (2, 1): Bit 2: End
-    "ldr x28, [x8, #0x78]\n"
     "ssubl v24.8h, v24.8b, v12.8b\n"
+    "ldr x22, [x12, #0x78]\n"
     "smlal v15.4s, v24.4h, v7.4h\n"
     "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "add x28, x28, x5\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v22.4s, v24.8h, v1.8h\n"
-    "tbz x4, #2, 41f\n"
-    "ld1 { v27.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 40f\n"
-    "ld1 { v27.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[6], [x28]\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 41f\n"
+    "ld1 { v27.s }[0], [x22], #0x4\n"
+    "tbz x8, #1, 40f\n"
+    "ld1 { v27.h }[2], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[6], [x22]\n"
     "b 43f\n"
     "40:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[4], [x28]\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[4], [x22]\n"
     "b 43f\n"
     "41:"  // Oddments: Load (3, 3): Bit 2: Unset
-    "tbz x4, #1, 42f\n"
-    "ld1 { v27.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[2], [x28]\n"
+    "tbz x8, #1, 42f\n"
+    "ld1 { v27.h }[0], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[2], [x22]\n"
     "b 43f\n"
     "42:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[0], [x28]\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[0], [x22]\n"
     "43:"  // Oddments: Load (3, 3): Bit 2: End
-    "ldr x27, [x8, #0x80]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v17.4s, v27.4h, v4.4h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
     "smlal2 v18.4s, v27.8h, v4.8h\n"
-    "add x27, x27, x5\n"
-    "tbz x4, #2, 45f\n"
-    "ld1 { v28.s }[0], [x27], #0x4\n"
-    "tbz x4, #1, 44f\n"
-    "ld1 { v28.h }[2], [x27], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[6], [x27]\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 45f\n"
+    "ld1 { v28.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 44f\n"
+    "ld1 { v28.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[6], [x21]\n"
     "b 47f\n"
     "44:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[4], [x27]\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[4], [x21]\n"
     "b 47f\n"
     "45:"  // Oddments: Load (2, 3): Bit 2: Unset
-    "tbz x4, #1, 46f\n"
-    "ld1 { v28.h }[0], [x27], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[2], [x27]\n"
+    "tbz x8, #1, 46f\n"
+    "ld1 { v28.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[2], [x21]\n"
     "b 47f\n"
     "46:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[0], [x27]\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[0], [x21]\n"
     "47:"  // Oddments: Load (2, 3): Bit 2: End
-    "ldr x26, [x8, #0x88]\n"
     "ssubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v7.4h\n"
-    "smlal2 v23.4s, v28.8h, v7.8h\n"
-    "add x26, x26, x5\n"
-    "smlal v17.4s, v28.4h, v1.4h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
     "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "tbz x4, #2, 49f\n"
-    "ld1 { v26.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 48f\n"
-    "ld1 { v26.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[6], [x26]\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 49f\n"
+    "ld1 { v26.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 48f\n"
+    "ld1 { v26.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[6], [x20]\n"
     "b 51f\n"
     "48:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[4], [x26]\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[4], [x20]\n"
     "b 51f\n"
     "49:"  // Oddments: Load (3, 4): Bit 2: Unset
-    "tbz x4, #1, 50f\n"
-    "ld1 { v26.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[2], [x26]\n"
+    "tbz x8, #1, 50f\n"
+    "ld1 { v26.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[2], [x20]\n"
     "b 51f\n"
     "50:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[0], [x26]\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[0], [x20]\n"
     "51:"  // Oddments: Load (3, 4): Bit 2: End
-    "ldr x25, [x8, #0x90]\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v5.4h\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
     "smlal2 v18.4s, v26.8h, v5.8h\n"
-    "add x25, x25, x5\n"
-    "tbz x4, #2, 53f\n"
-    "ld1 { v25.s }[0], [x25], #0x4\n"
-    "tbz x4, #1, 52f\n"
-    "ld1 { v25.h }[2], [x25], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[6], [x25]\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 53f\n"
+    "ld1 { v25.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 52f\n"
+    "ld1 { v25.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[6], [x23]\n"
     "b 55f\n"
     "52:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[4], [x25]\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[4], [x23]\n"
     "b 55f\n"
     "53:"  // Oddments: Load (4, 0): Bit 2: Unset
-    "tbz x4, #1, 54f\n"
-    "ld1 { v25.h }[0], [x25], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[2], [x25]\n"
+    "tbz x8, #1, 54f\n"
+    "ld1 { v25.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[2], [x23]\n"
     "b 55f\n"
     "54:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[0], [x25]\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[0], [x23]\n"
     "55:"  // Oddments: Load (4, 0): Bit 2: End
-    "ldr x24, [x8, #0x98]\n"
     "ssubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v25.4h, v6.4h\n"
-    "smlal2 v22.4s, v25.8h, v6.8h\n"
-    "add x24, x24, x5\n"
-    "tbz x4, #2, 57f\n"
+    "ldr x24, [x12, #0x98]\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal2 v21.4s, v25.8h, v6.8h\n"
+    "add x24, x24, x15\n"
+    "tbz x8, #2, 57f\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
-    "tbz x4, #1, 56f\n"
+    "tbz x8, #1, 56f\n"
     "ld1 { v29.h }[2], [x24], #0x2\n"
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[6], [x24]\n"
     "b 59f\n"
     "56:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[4], [x24]\n"
     "b 59f\n"
     "57:"  // Oddments: Load (2, 4): Bit 2: Unset
-    "tbz x4, #1, 58f\n"
+    "tbz x8, #1, 58f\n"
     "ld1 { v29.h }[0], [x24], #0x2\n"
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[2], [x24]\n"
     "b 59f\n"
     "58:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[0], [x24]\n"
     "59:"  // Oddments: Load (2, 4): Bit 2: End
-    "ldr x23, [x8, #0xa0]\n"
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v8.4h\n"
-    "smlal2 v23.4s, v29.8h, v8.8h\n"
-    "add x23, x23, x5\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
     "smlal2 v18.4s, v29.8h, v2.8h\n"
-    "tbz x4, #2, 61f\n"
-    "ld1 { v27.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 60f\n"
-    "ld1 { v27.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[6], [x23]\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 61f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 60f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 63f\n"
     "60:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[4], [x23]\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 63f\n"
     "61:"  // Oddments: Load (4, 1): Bit 2: Unset
-    "tbz x4, #1, 62f\n"
-    "ld1 { v27.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[2], [x23]\n"
+    "tbz x8, #1, 62f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 63f\n"
     "62:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[0], [x23]\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "63:"  // Oddments: Load (4, 1): Bit 2: End
-    "ldr x22, [x8, #0xa8]\n"
     "ssubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v27.4h, v7.4h\n"
-    "smlal2 v22.4s, v27.8h, v7.8h\n"
-    "add x22, x22, x5\n"
-    "tbz x4, #2, 65f\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 65f\n"
     "ld1 { v24.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 64f\n"
+    "tbz x8, #1, 64f\n"
     "ld1 { v24.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[6], [x22]\n"
     "b 67f\n"
     "64:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[4], [x22]\n"
     "b 67f\n"
     "65:"  // Oddments: Load (3, 2): Bit 2: Unset
-    "tbz x4, #1, 66f\n"
+    "tbz x8, #1, 66f\n"
     "ld1 { v24.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[2], [x22]\n"
     "b 67f\n"
     "66:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[0], [x22]\n"
     "67:"  // Oddments: Load (3, 2): Bit 2: End
-    "ldr x21, [x8, #0xb0]\n"
     "ssubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v16.4s, v24.4h, v5.4h\n"
-    "smlal2 v22.4s, v24.8h, v5.8h\n"
-    "add x21, x21, x5\n"
-    "smlal v17.4s, v24.4h, v3.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
     "smlal2 v18.4s, v24.8h, v3.8h\n"
-    "tbz x4, #2, 69f\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 69f\n"
     "ld1 { v26.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 68f\n"
+    "tbz x8, #1, 68f\n"
     "ld1 { v26.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[6], [x21]\n"
     "b 71f\n"
     "68:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[4], [x21]\n"
     "b 71f\n"
     "69:"  // Oddments: Load (4, 3): Bit 2: Unset
-    "tbz x4, #1, 70f\n"
+    "tbz x8, #1, 70f\n"
     "ld1 { v26.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[2], [x21]\n"
     "b 71f\n"
     "70:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[0], [x21]\n"
     "71:"  // Oddments: Load (4, 3): Bit 2: End
-    "ldr x20, [x8, #0xb8]\n"
     "ssubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v7.4h\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
     "smlal2 v18.4s, v26.8h, v7.8h\n"
-    "add x20, x20, x5\n"
-    "tbz x4, #2, 73f\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 73f\n"
     "ld1 { v25.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 72f\n"
+    "tbz x8, #1, 72f\n"
     "ld1 { v25.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[6], [x20]\n"
     "b 75f\n"
     "72:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[4], [x20]\n"
     "b 75f\n"
     "73:"  // Oddments: Load (4, 2): Bit 2: Unset
-    "tbz x4, #1, 74f\n"
+    "tbz x8, #1, 74f\n"
     "ld1 { v25.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[2], [x20]\n"
     "b 75f\n"
     "74:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[0], [x20]\n"
     "75:"  // Oddments: Load (4, 2): Bit 2: End
-    "ldr x19, [x8, #0xc0]\n"
     "ssubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v25.4h, v8.4h\n"
-    "smlal2 v22.4s, v25.8h, v8.8h\n"
-    "add x19, x19, x5\n"
-    "smlal v17.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
     "smlal2 v18.4s, v25.8h, v6.8h\n"
-    "tbz x4, #2, 77f\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 77f\n"
     "ld1 { v29.s }[0], [x19], #0x4\n"
-    "tbz x4, #1, 76f\n"
+    "tbz x8, #1, 76f\n"
     "ld1 { v29.h }[2], [x19], #0x2\n"
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[6], [x19]\n"
     "b 79f\n"
     "76:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[4], [x19]\n"
     "b 79f\n"
     "77:"  // Oddments: Load (4, 4): Bit 2: Unset
-    "tbz x4, #1, 78f\n"
+    "tbz x8, #1, 78f\n"
     "ld1 { v29.h }[0], [x19], #0x2\n"
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[2], [x19]\n"
     "b 79f\n"
     "78:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[0], [x19]\n"
     "79:"  // Oddments: Load (4, 4): Bit 2: End
     "ssubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v17.4s, v29.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
     "smlal2 v18.4s, v29.8h, v8.8h\n"
-    "tbz x4, #2, 81f\n"
-    "ld1 { v21.4s }, [x17], #0x10\n"
-    "ld1 { v30.4s }, [x15], #0x10\n"
-    "tbz x4, #1, 80f\n"
-    "ld1 { v31.d }[0], [x17], #0x8\n"
-    "ld1 { v9.d }[0], [x15], #0x8\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v31.s }[2], [x17]\n"
-    "ld1 { v9.s }[2], [x15]\n"
+    "tbz x8, #2, 81f\n"
+    "ld1 { v19.4s }, [x13], #0x10\n"
+    "ld1 { v0.4s }, [x11], #0x10\n"
+    "tbz x8, #1, 80f\n"
+    "ld1 { v4.d }[0], [x13], #0x8\n"
+    "ld1 { v31.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v4.s }[2], [x13]\n"
+    "ld1 { v31.s }[2], [x11]\n"
     "b 83f\n"
     "80:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v31.s }[0], [x17]\n"
-    "ld1 { v9.s }[0], [x15]\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v4.s }[0], [x13]\n"
+    "ld1 { v31.s }[0], [x11]\n"
     "b 83f\n"
     "81:"  // Oddments: Load requant params: Bit 2: Unset
-    "tbz x4, #1, 82f\n"
-    "ld1 { v21.d }[0], [x17], #0x8\n"
-    "ld1 { v30.d }[0], [x15], #0x8\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v21.s }[2], [x17]\n"
-    "ld1 { v30.s }[2], [x15]\n"
+    "tbz x8, #1, 82f\n"
+    "ld1 { v19.d }[0], [x13], #0x8\n"
+    "ld1 { v0.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[2], [x13]\n"
+    "ld1 { v0.s }[2], [x11]\n"
     "b 83f\n"
     "82:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v21.s }[0], [x17]\n"
-    "ld1 { v30.s }[0], [x15]\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[0], [x13]\n"
+    "ld1 { v0.s }[0], [x11]\n"
     "83:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v15.4s, v15.4s, v21.4s\n"
-    "add x14, x14, x7\n"
-    "sqrdmulh v10.4s, v10.4s, v31.4s\n"
-    "add x13, x13, x7\n"
-    "sqrdmulh v20.4s, v20.4s, v21.4s\n"
-    "add x12, x12, x7\n"
-    "sqrdmulh v23.4s, v23.4s, v31.4s\n"
-    "add x11, x11, x7\n"
-    "sqrdmulh v16.4s, v16.4s, v21.4s\n"
-    "and v26.16b, v15.16b, v30.16b\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "and v8.16b, v10.16b, v9.16b\n"
-    "and v4.16b, v20.16b, v30.16b\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "and v2.16b, v23.16b, v9.16b\n"
-    "and v1.16b, v16.16b, v30.16b\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "and v19.16b, v10.16b, v31.16b\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "sqrdmulh v22.4s, v22.4s, v31.4s\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v26.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v21.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "sqadd v10.4s, v10.4s, v8.4s\n"
-    "sqadd v20.4s, v20.4s, v4.4s\n"
-    "srshl v15.4s, v15.4s, v30.4s\n"
-    "sqadd v23.4s, v23.4s, v2.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "srshl v20.4s, v20.4s, v30.4s\n"
-    "add v15.4s, v15.4s, v11.4s\n"
-    "srshl v23.4s, v23.4s, v9.4s\n"
-    "add v10.4s, v10.4s, v11.4s\n"
-    "smin v15.4s, v15.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v11.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "smax v15.4s, v15.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v14.4s\n"
-    "smax v10.4s, v10.4s, v19.4s\n"
-    "add v23.4s, v23.4s, v11.4s\n"
-    "smax v20.4s, v20.4s, v19.4s\n"
-    "uzp1 v15.16b, v15.16b, v10.16b\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "sqadd v16.4s, v16.4s, v1.4s\n"
-    "smax v23.4s, v23.4s, v19.4s\n"
-    "and v24.16b, v22.16b, v9.16b\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v23.16b\n"
-    "srshl v16.4s, v16.4s, v30.4s\n"
-    "and v2.16b, v17.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v20.16b\n"
-    "add v16.4s, v16.4s, v11.4s\n"
-    "sqadd v22.4s, v22.4s, v24.4s\n"
-    "and v31.16b, v18.16b, v9.16b\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "smin v16.4s, v16.4s, v14.4s\n"
-    "srshl v22.4s, v22.4s, v9.4s\n"
-    "sqadd v17.4s, v17.4s, v2.4s\n"
-    "smax v16.4s, v16.4s, v19.4s\n"
-    "add v22.4s, v22.4s, v11.4s\n"
-    "srshl v17.4s, v17.4s, v30.4s\n"
-    "sqadd v18.4s, v18.4s, v31.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "add v17.4s, v17.4s, v11.4s\n"
-    "srshl v18.4s, v18.4s, v9.4s\n"
-    "smax v22.4s, v22.4s, v19.4s\n"
-    "smin v17.4s, v17.4s, v14.4s\n"
-    "uzp1 v16.16b, v16.16b, v22.16b\n"
-    "add v18.4s, v18.4s, v11.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "smax v17.4s, v17.4s, v19.4s\n"
-    "smin v18.4s, v18.4s, v14.4s\n"
-    "smax v18.4s, v18.4s, v19.4s\n"
-    "uzp1 v17.16b, v17.16b, v18.16b\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "tbz x4, #2, 85f\n"
-    "st1 { v15.s }[0], [x14], #0x4\n"
-    "st1 { v20.s }[0], [x13], #0x4\n"
-    "st1 { v16.s }[0], [x12], #0x4\n"
-    "st1 { v17.s }[0], [x11], #0x4\n"
-    "tbz x4, #1, 84f\n"
-    "st1 { v15.h }[2], [x14], #0x2\n"
-    "st1 { v20.h }[2], [x13], #0x2\n"
-    "st1 { v16.h }[2], [x12], #0x2\n"
-    "st1 { v17.h }[2], [x11], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[6], [x14], #0x1\n"
-    "st1 { v20.b }[6], [x13], #0x1\n"
-    "st1 { v16.b }[6], [x12], #0x1\n"
-    "st1 { v17.b }[6], [x11], #0x1\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "tbz x8, #2, 85f\n"
+    "st1 { v15.s }[0], [x10], #0x4\n"
+    "st1 { v9.s }[0], [x9], #0x4\n"
+    "st1 { v22.s }[0], [x28], #0x4\n"
+    "st1 { v23.s }[0], [x27], #0x4\n"
+    "tbz x8, #1, 84f\n"
+    "st1 { v15.h }[2], [x10], #0x2\n"
+    "st1 { v9.h }[2], [x9], #0x2\n"
+    "st1 { v22.h }[2], [x28], #0x2\n"
+    "st1 { v23.h }[2], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[6], [x10], #0x1\n"
+    "st1 { v9.b }[6], [x9], #0x1\n"
+    "st1 { v22.b }[6], [x28], #0x1\n"
+    "st1 { v23.b }[6], [x27], #0x1\n"
     "b 87f\n"
     "84:"  // Oddments: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[4], [x14], #0x1\n"
-    "st1 { v20.b }[4], [x13], #0x1\n"
-    "st1 { v16.b }[4], [x12], #0x1\n"
-    "st1 { v17.b }[4], [x11], #0x1\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[4], [x10], #0x1\n"
+    "st1 { v9.b }[4], [x9], #0x1\n"
+    "st1 { v22.b }[4], [x28], #0x1\n"
+    "st1 { v23.b }[4], [x27], #0x1\n"
     "b 87f\n"
     "85:"  // Oddments: Bit 2: Unset
-    "tbz x4, #1, 86f\n"
-    "st1 { v15.h }[0], [x14], #0x2\n"
-    "st1 { v20.h }[0], [x13], #0x2\n"
-    "st1 { v16.h }[0], [x12], #0x2\n"
-    "st1 { v17.h }[0], [x11], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[2], [x14], #0x1\n"
-    "st1 { v20.b }[2], [x13], #0x1\n"
-    "st1 { v16.b }[2], [x12], #0x1\n"
-    "st1 { v17.b }[2], [x11], #0x1\n"
+    "tbz x8, #1, 86f\n"
+    "st1 { v15.h }[0], [x10], #0x2\n"
+    "st1 { v9.h }[0], [x9], #0x2\n"
+    "st1 { v22.h }[0], [x28], #0x2\n"
+    "st1 { v23.h }[0], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[2], [x10], #0x1\n"
+    "st1 { v9.b }[2], [x9], #0x1\n"
+    "st1 { v22.b }[2], [x28], #0x1\n"
+    "st1 { v23.b }[2], [x27], #0x1\n"
     "b 87f\n"
     "86:"  // Oddments: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[0], [x14], #0x1\n"
-    "st1 { v20.b }[0], [x13], #0x1\n"
-    "st1 { v16.b }[0], [x12], #0x1\n"
-    "st1 { v17.b }[0], [x11], #0x1\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[0], [x10], #0x1\n"
+    "st1 { v9.b }[0], [x9], #0x1\n"
+    "st1 { v22.b }[0], [x28], #0x1\n"
+    "st1 { v23.b }[0], [x27], #0x1\n"
     "87:"  // Oddments: Bit 2: End
-
     "88:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
-    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
   );
 }
 
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index a998fa1..52031e1 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
 
-struct a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst
+class a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 5, 5, 1, 1) {}
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_5x5_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_5x5_mla::get_packed_size;
-
-  kern_type kernel = a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
-
-  a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
index ab64f53..bd71b65 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const int8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const int8_t *const *inptrs_raw,
-      const int8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -111,2096 +111,2070 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
-    "ldr x4, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x10, #0x0\n"
-    "ldr x3, [%x[params], %[offsetof_Params_weights]]\n"
-    "mov x1, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x25, %x[params], %[offsetof_Params_inptrs]\n"
-    "ldr x2, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x19, x4, #0x3\n"
-    "ldr x5, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x13, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v7.16b }, [x13]\n"
-    "add x8, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v13.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v19.4s }, [x8]\n"
-    "add x8, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v16.4s }, [x20]\n"
-    "ld1r { v12.4s }, [x8]\n"
-    "ldp x17, x16, [x21, #0x0]\n"
-    "ldp x6, x8, [x21, #0x10]\n"
-    "cbz x19, 3f\n"
-    "subs x19, x19, #0x1\n"
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q15, [x12, #0x0]\n"
-    "mov v18.16b, v15.16b\n"
-    "ldr q20, [x12, #0x10]\n"
-    "add x12, x12, #0x20\n"
-    "mov v11.16b, v15.16b\n"
-    "str x12, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr x10, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x0, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x17, x10, %[offsetof_Requantize32_a_offset]\n"
+    "add x9, x10, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x25, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x4, x10, %[offsetof_Requantize32_c_offset]\n"
+    "add x14, x10, %[offsetof_Requantize32_minval]\n"
+    "ldr x23, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x5, x10, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v9.16b }, [x17]\n"
+    "ld1r { v14.16b }, [x9]\n"
+    "lsr x3, x0, #0x3\n"
+    "ld1r { v18.8h }, [x4]\n"
+    "ld1r { v11.8h }, [x14]\n"
+    "mov x24, #0x0\n"
+    "mov x22, #0x0\n"
+    "ld1r { v13.8h }, [x5]\n"
+    "ldr x10, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "add x20, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x1, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x16, x8, [x25, #0x0]\n"
+    "ldp x4, x7, [x25, #0x10]\n"
+    "cbz x3, 3f\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "subs x3, x3, #0x1\n"
+    "mov v17.16b, v15.16b\n"
+    "ldr q16, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "ldr d2, [x23, #0x10]\n"
+    "mov v8.16b, v16.16b\n"
     "mov v10.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "mov v5.16b, v20.16b\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v8.16b, v20.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "mov v9.16b, v20.16b\n"
-    "ldr d3, [x3, #0x18]\n"
-    "ldr d4, [x3, #0x20]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "ldr d31, [x28, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "ldr d30, [x27, x10]\n"
-    "ldr d29, [x26, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "ldr d28, [x13, x10]\n"
-    "ldr d27, [x24, x10]\n"
-    "ssubl v29.8h, v29.8b, v7.8b\n"
-    "ldr d23, [x23, x10]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "ldr d25, [x22, x10]\n"
-    "ldr d24, [x21, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "ldr d26, [x20, x10]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "ldr d22, [x0, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "ssubl v22.8h, v22.8b, v7.8b\n"
+    "ldr d3, [x23, #0x18]\n"
+    "ldr d4, [x23, #0x20]\n"
+    "mov v7.16b, v16.16b\n"
+    "mov v6.16b, v15.16b\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "mov v5.16b, v16.16b\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "ldp x27, x21, [x20, #0x30]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "ldr d31, [x28, x24]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr d30, [x6, x24]\n"
+    "ldr d29, [x26, x24]\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "ldr d28, [x25, x24]\n"
+    "ldr d27, [x5, x24]\n"
+    "ssubl v29.8h, v29.8b, v9.8b\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "ldr d23, [x2, x24]\n"
+    "ldr d25, [x27, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "ldr d24, [x21, x24]\n"
+    "ldr d26, [x12, x24]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "ldr d22, [x19, x24]\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "ssubl v22.8h, v22.8b, v9.8b\n"
     "beq 2f\n"
     "1:"  // Loop
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "ldr x20, [x25, #0x50]\n"
-    "subs x19, x19, #0x1\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "ldr d31, [x20, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v11.4s, v29.4h, v0.4h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal2 v8.4s, v29.8h, v0.8h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "smlal v10.4s, v28.4h, v0.4h\n"
-    "ldr x23, [x25, #0x78]\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x20, #0x50]\n"
+    "ldr d31, [x19, x24]\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "ldr x15, [x20, #0x58]\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x20, #0x60]\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v1.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "smlal2 v5.4s, v27.8h, v1.8h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
-    "ldr x21, [x25, #0x98]\n"
-    "smlal2 v8.4s, v28.8h, v1.8h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "smlal v10.4s, v23.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "smlal2 v9.4s, v23.8h, v1.8h\n"
-    "ldr d1, [x3, #0x30]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x0, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "smlal v11.4s, v23.4h, v2.4h\n"
-    "ldr x9, [x25, #0xc8]\n"
-    "smlal2 v8.4s, v23.8h, v2.8h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "smlal v10.4s, v31.4h, v2.4h\n"
-    "ldr x28, [x25, #0xd8]\n"
-    "smlal2 v9.4s, v31.8h, v2.8h\n"
-    "ldr d2, [x3, #0x38]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr q6, [x2, #0x0]\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x7, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "ldr q21, [x5, #0x0]\n"
-    "smlal v11.4s, v31.4h, v3.4h\n"
-    "ldr q17, [x2, #0x10]\n"
-    "add x2, x2, #0x20\n"
-    "smlal2 v8.4s, v31.8h, v3.8h\n"
-    "ldr q14, [x5, #0x10]\n"
-    "add x5, x5, #0x20\n"
-    "smlal v10.4s, v30.4h, v3.4h\n"
-    "smlal2 v9.4s, v30.8h, v3.8h\n"
-    "ldr d3, [x3, #0x40]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x26, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
-    "smlal2 v5.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v11.4s, v30.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "smlal2 v8.4s, v30.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0x48]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v20.4s, v29.8h, v0.8h\n"
-    "smlal v18.4s, v28.4h, v0.4h\n"
+    "ldr x5, [x20, #0x70]\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
     "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "smlal v11.4s, v22.4h, v0.4h\n"
-    "smlal2 v8.4s, v22.8h, v0.8h\n"
-    "smlal v10.4s, v25.4h, v0.4h\n"
-    "smlal2 v9.4s, v25.8h, v0.8h\n"
-    "ldr d0, [x3, #0x50]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x10]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v1.4h\n"
-    "ldr x23, [x25, #0xf8]\n"
+    "ldr d30, [x15, x24]\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x23, #0x28]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "ldr x12, [x20, #0x80]\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
+    "smlal v15.4s, v27.4h, v2.4h\n"
+    "ldr x14, [x20, #0x90]\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
     "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "smlal v11.4s, v25.4h, v1.4h\n"
-    "smlal2 v8.4s, v25.8h, v1.8h\n"
-    "smlal v10.4s, v24.4h, v1.4h\n"
-    "smlal2 v9.4s, v24.8h, v1.8h\n"
-    "ldr d1, [x3, #0x58]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v20.4s, v23.8h, v2.8h\n"
-    "ldr d23, [x20, x10]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v31.4h, v2.4h\n"
-    "ldr x22, [x25, #0x100]\n"
-    "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v24.4h, v2.4h\n"
-    "smlal2 v8.4s, v24.8h, v2.8h\n"
-    "smlal v10.4s, v27.4h, v2.4h\n"
-    "smlal2 v9.4s, v27.8h, v2.8h\n"
-    "ldr d2, [x3, #0x60]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v20.4s, v31.8h, v3.8h\n"
-    "ldr d31, [x13, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v3.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "smlal v11.4s, v27.4h, v3.4h\n"
-    "smlal2 v8.4s, v27.8h, v3.8h\n"
-    "smlal v10.4s, v23.4h, v3.4h\n"
-    "smlal2 v9.4s, v23.8h, v3.8h\n"
-    "ldr d3, [x3, #0x68]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d26, [x14, x10]\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v11.4s, v23.4h, v4.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "smlal2 v8.4s, v23.8h, v4.8h\n"
-    "smlal v10.4s, v28.4h, v4.4h\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "ldr d4, [x3, #0x70]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v20.4s, v22.8h, v0.8h\n"
-    "ldr d22, [x0, x10]\n"
-    "ssubl v22.8h, v22.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "smlal v11.4s, v31.4h, v0.4h\n"
-    "smlal2 v8.4s, v31.8h, v0.8h\n"
-    "smlal v10.4s, v30.4h, v0.4h\n"
-    "smlal2 v9.4s, v30.8h, v0.8h\n"
-    "ldr d0, [x3, #0x78]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v20.4s, v25.8h, v1.8h\n"
-    "ldr d25, [x11, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "smlal v11.4s, v30.4h, v1.4h\n"
-    "smlal2 v8.4s, v30.8h, v1.8h\n"
-    "smlal v10.4s, v26.4h, v1.4h\n"
-    "smlal2 v9.4s, v26.8h, v1.8h\n"
-    "ldr d1, [x3, #0x80]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v20.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x24, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "smlal v11.4s, v26.4h, v2.4h\n"
-    "smlal2 v8.4s, v26.8h, v2.8h\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "ldr d2, [x3, #0x88]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x15, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "ldr d3, [x3, #0x90]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v20.4s, v23.8h, v4.8h\n"
-    "ldr d23, [x9, x10]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "ldr d28, [x12, x10]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v22.4h, v4.4h\n"
-    "smlal2 v9.4s, v22.8h, v4.8h\n"
-    "ldr d4, [x3, #0x98]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x27, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "smlal v11.4s, v27.4h, v0.4h\n"
-    "smlal2 v8.4s, v27.8h, v0.8h\n"
-    "smlal v10.4s, v23.4h, v0.4h\n"
-    "smlal2 v9.4s, v23.8h, v0.8h\n"
-    "ldr d0, [x3, #0xa0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "smlal v11.4s, v23.4h, v1.4h\n"
-    "smlal2 v8.4s, v23.8h, v1.8h\n"
-    "smlal v10.4s, v31.4h, v1.4h\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "ldr d1, [x3, #0xa8]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v20.4s, v26.8h, v2.8h\n"
-    "ldr d26, [x7, x10]\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "smlal2 v8.4s, v31.8h, v2.8h\n"
-    "smlal v10.4s, v30.4h, v2.4h\n"
-    "smlal2 v9.4s, v30.8h, v2.8h\n"
-    "ldr d2, [x3, #0xb0]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
+    "ldr d27, [x19, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x23, #0x30]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x26, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v30.4h, v3.4h\n"
-    "smlal2 v8.4s, v30.8h, v3.8h\n"
-    "smlal v10.4s, v28.4h, v3.4h\n"
-    "smlal2 v9.4s, v28.8h, v3.8h\n"
-    "ldr d3, [x3, #0xb8]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v31.8h, v2.8h\n"
+    "ldr d25, [x27, x24]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x23, #0x38]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x23, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "smlal v11.4s, v28.4h, v4.4h\n"
-    "smlal2 v8.4s, v28.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0xc0]\n"
-    "add x3, x3, #0xc8\n"
-    "smlal v15.4s, v27.4h, v0.4h\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v20.4s, v27.8h, v0.8h\n"
-    "ldr d27, [x22, x10]\n"
-    "smlal v18.4s, v23.4h, v0.4h\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v8.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x20, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v0.4h\n"
-    "smlal2 v9.4s, v24.8h, v0.8h\n"
-    "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v20.4s, v23.8h, v1.8h\n"
-    "smlal v18.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v8.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x13, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v1.4h\n"
-    "smlal2 v9.4s, v27.8h, v1.8h\n"
-    "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v20.4s, v31.8h, v2.8h\n"
-    "smlal v18.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "smlal v11.4s, v27.4h, v2.4h\n"
-    "smlal2 v8.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x10]\n"
-    "add x10, x10, #0x8\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v20.4s, v30.8h, v3.8h\n"
-    "smlal v18.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v20.4s, v28.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v30.8h, v3.8h\n"
+    "ldr d24, [x5, x24]\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x23, #0x40]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x11, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
+    "smlal v15.4s, v29.4h, v0.4h\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
     "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v27.4h, v4.4h\n"
-    "smlal2 v9.4s, v27.8h, v4.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v17.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v6.4s\n"
-    "sqrdmulh v5.4s, v5.4s, v17.4s\n"
-    "and v1.16b, v15.16b, v21.16b\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "and v29.16b, v20.16b, v14.16b\n"
-    "and v3.16b, v18.16b, v21.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "and v2.16b, v5.16b, v14.16b\n"
-    "sqrdmulh v11.4s, v11.4s, v6.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v8.4s, v8.4s, v17.4s\n"
+    "ldr d4, [x23, #0x48]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "ldr q12, [x10, #0x0]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "ldr q19, [x1, #0x0]\n"
+    "ldr q20, [x10, #0x10]\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "smlal v15.4s, v28.4h, v1.4h\n"
+    "ldr q29, [x1, #0x10]\n"
+    "subs x3, x3, #0x1\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x26, x24]\n"
+    "ldr d0, [x23, #0x50]\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "add x10, x10, #0x20\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "smlal v15.4s, v23.4h, v2.4h\n"
+    "add x1, x1, #0x20\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x12, x24]\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x23, #0x58]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "smlal v15.4s, v31.4h, v3.4h\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x14, x24]\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x23, #0x60]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "smlal v15.4s, v30.4h, v4.4h\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x15, x24]\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x23, #0x68]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x21, x24]\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "smlal v15.4s, v22.4h, v0.4h\n"
+    "ldr x14, [x20, #0x110]\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x23, #0x70]\n"
+    "ldr d22, [x9, x24]\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "ssubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v25.4h, v1.4h\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x2, x24]\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x23, #0x78]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x13, x24]\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x23, #0x80]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "smlal v15.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x23, #0x88]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal v15.4s, v23.4h, v4.4h\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x28, x24]\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x23, #0x90]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x11, x24]\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal v15.4s, v31.4h, v0.4h\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x6, x24]\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x23, #0x98]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x24]\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v2.4h\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x17, x24]\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v3.4h\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x5, x24]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "smlal v15.4s, v24.4h, v4.4h\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x25, x24]\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "smlal v15.4s, v27.4h, v0.4h\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x26, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "add x23, x23, #0xc8\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x12, x24]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
+    "smlal2 v5.4s, v24.8h, v0.8h\n"
+    "smlal v15.4s, v23.4h, v1.4h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x14, x24]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
+    "smlal2 v5.4s, v27.8h, v1.8h\n"
+    "smlal v15.4s, v31.4h, v2.4h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x24]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x24, x24, #0x8\n"
+    "smlal v15.4s, v30.4h, v3.4h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "smlal v15.4s, v28.4h, v4.4h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
     "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v1.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v6.4s\n"
-    "and v0.16b, v11.16b, v21.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v21.4s\n"
-    "sqadd v20.4s, v20.4s, v29.4s\n"
-    "sqadd v18.4s, v18.4s, v3.4s\n"
-    "sqadd v5.4s, v5.4s, v2.4s\n"
-    "and v27.16b, v8.16b, v14.16b\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v21.4s\n"
-    "srshl v5.4s, v5.4s, v14.4s\n"
-    "smin v15.4s, v15.4s, v12.4s\n"
-    "add v20.4s, v20.4s, v19.4s\n"
-    "add v18.4s, v18.4s, v19.4s\n"
-    "smax v15.4s, v15.4s, v16.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smin v18.4s, v18.4s, v12.4s\n"
-    "add v5.4s, v5.4s, v19.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "smin v5.4s, v5.4s, v12.4s\n"
-    "uzp1 v15.16b, v15.16b, v20.16b\n"
-    "sqadd v11.4s, v11.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x17, x1]\n"
-    "smax v5.4s, v5.4s, v16.4s\n"
-    "sqadd v8.4s, v8.4s, v27.4s\n"
-    "srshl v11.4s, v11.4s, v21.4s\n"
-    "and v30.16b, v10.16b, v21.16b\n"
-    "sshr v30.4s, v30.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v5.16b\n"
-    "add v11.4s, v11.4s, v19.4s\n"
-    "srshl v8.4s, v8.4s, v14.4s\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "str d18, [x16, x1]\n"
-    "smin v11.4s, v11.4s, v12.4s\n"
-    "sqrdmulh v9.4s, v9.4s, v17.4s\n"
-    "add v8.4s, v8.4s, v19.4s\n"
-    "sqadd v10.4s, v10.4s, v30.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "smin v8.4s, v8.4s, v12.4s\n"
-    "and v6.16b, v9.16b, v14.16b\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "smax v8.4s, v8.4s, v16.4s\n"
-    "srshl v10.4s, v10.4s, v21.4s\n"
-    "uzp1 v11.16b, v11.16b, v8.16b\n"
-    "add v10.4s, v10.4s, v19.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x6, x1]\n"
-    "smin v10.4s, v10.4s, v12.4s\n"
-    "sqadd v9.4s, v9.4s, v6.4s\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "srshl v9.4s, v9.4s, v14.4s\n"
-    "add v9.4s, v9.4s, v19.4s\n"
-    "smin v9.4s, v9.4s, v12.4s\n"
-    "smax v9.4s, v9.4s, v16.4s\n"
-    "uzp1 v10.16b, v10.16b, v9.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
+    "str d15, [x16, x22]\n"
     "uzp1 v10.16b, v10.16b, v10.16b\n"
-    "str d10, [x8, x1]\n"
-    "add x1, x1, #0x8\n"
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q15, [x12, #0x0]\n"
-    "mov v18.16b, v15.16b\n"
-    "ldr q20, [x12, #0x10]\n"
-    "add x12, x12, #0x20\n"
-    "mov v11.16b, v15.16b\n"
-    "str x12, [%x[params], %[offsetof_Params_bias]]\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "str d17, [x8, x22]\n"
+    "str d10, [x4, x22]\n"
+    "str d6, [x7, x22]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "add x22, x22, #0x8\n"
+    "ldr q16, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "ldr d2, [x23, #0x10]\n"
+    "mov v17.16b, v15.16b\n"
+    "mov v8.16b, v16.16b\n"
+    "ldr d3, [x23, #0x18]\n"
+    "ldr d4, [x23, #0x20]\n"
     "mov v10.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "mov v5.16b, v20.16b\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v8.16b, v20.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "mov v9.16b, v20.16b\n"
-    "ldr d3, [x3, #0x18]\n"
-    "ldr d4, [x3, #0x20]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "ldr d31, [x28, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "ldr d30, [x27, x10]\n"
-    "ldr d29, [x26, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "ldr d28, [x13, x10]\n"
-    "ldr d27, [x24, x10]\n"
-    "ssubl v29.8h, v29.8b, v7.8b\n"
-    "ldr d23, [x23, x10]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "ldr d25, [x22, x10]\n"
-    "ldr d24, [x21, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "ldr d26, [x20, x10]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "ldr d22, [x0, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "ssubl v22.8h, v22.8b, v7.8b\n"
+    "mov v7.16b, v16.16b\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "mov v6.16b, v15.16b\n"
+    "mov v5.16b, v16.16b\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "ldp x27, x21, [x20, #0x30]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "ldr d31, [x28, x24]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr d30, [x6, x24]\n"
+    "ldr d29, [x26, x24]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "ldr d28, [x25, x24]\n"
+    "ldr d27, [x5, x24]\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "ssubl v29.8h, v29.8b, v9.8b\n"
+    "ldr d23, [x2, x24]\n"
+    "ldr d25, [x27, x24]\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "ldr d24, [x21, x24]\n"
+    "ldr d26, [x12, x24]\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "ldr d22, [x19, x24]\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "ssubl v22.8h, v22.8b, v9.8b\n"
     "bgt 1b\n"
     "2:"  // Tail
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "ldr x20, [x25, #0x50]\n"
-    "tst x4, #0x7\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "ldr d31, [x20, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v11.4s, v29.4h, v0.4h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal2 v8.4s, v29.8h, v0.8h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "smlal v10.4s, v28.4h, v0.4h\n"
-    "ldr x23, [x25, #0x78]\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x20, #0x50]\n"
+    "ldr d31, [x19, x24]\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "ldr x15, [x20, #0x58]\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x20, #0x60]\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v1.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "smlal2 v5.4s, v27.8h, v1.8h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
-    "ldr x21, [x25, #0x98]\n"
-    "smlal2 v8.4s, v28.8h, v1.8h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "smlal v10.4s, v23.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "smlal2 v9.4s, v23.8h, v1.8h\n"
-    "ldr d1, [x3, #0x30]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x0, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "smlal v11.4s, v23.4h, v2.4h\n"
-    "ldr x9, [x25, #0xc8]\n"
-    "smlal2 v8.4s, v23.8h, v2.8h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "smlal v10.4s, v31.4h, v2.4h\n"
-    "ldr x28, [x25, #0xd8]\n"
-    "smlal2 v9.4s, v31.8h, v2.8h\n"
-    "ldr d2, [x3, #0x38]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x7, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "ldr q6, [x2, #0x0]\n"
-    "smlal v11.4s, v31.4h, v3.4h\n"
-    "ldr q21, [x5, #0x0]\n"
-    "smlal2 v8.4s, v31.8h, v3.8h\n"
-    "ldr q17, [x2, #0x10]\n"
-    "add x2, x2, #0x20\n"
-    "smlal v10.4s, v30.4h, v3.4h\n"
-    "ldr q14, [x5, #0x10]\n"
-    "add x5, x5, #0x20\n"
-    "smlal2 v9.4s, v30.8h, v3.8h\n"
-    "ldr d3, [x3, #0x40]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x26, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "smlal2 v5.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v11.4s, v30.4h, v4.4h\n"
-    "ldr x23, [x25, #0xf8]\n"
-    "smlal2 v8.4s, v30.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0x48]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v20.4s, v29.8h, v0.8h\n"
-    "smlal v18.4s, v28.4h, v0.4h\n"
+    "ldr x5, [x20, #0x70]\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
     "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "smlal v11.4s, v22.4h, v0.4h\n"
-    "smlal2 v8.4s, v22.8h, v0.8h\n"
-    "smlal v10.4s, v25.4h, v0.4h\n"
-    "smlal2 v9.4s, v25.8h, v0.8h\n"
-    "ldr d0, [x3, #0x50]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x10]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v1.4h\n"
-    "ldr x22, [x25, #0x100]\n"
+    "ldr d30, [x15, x24]\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x23, #0x28]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "ldr x12, [x20, #0x80]\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
+    "smlal v15.4s, v27.4h, v2.4h\n"
+    "ldr x14, [x20, #0x90]\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
     "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "smlal v11.4s, v25.4h, v1.4h\n"
-    "smlal2 v8.4s, v25.8h, v1.8h\n"
-    "smlal v10.4s, v24.4h, v1.4h\n"
-    "smlal2 v9.4s, v24.8h, v1.8h\n"
-    "ldr d1, [x3, #0x58]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v20.4s, v23.8h, v2.8h\n"
-    "ldr d23, [x20, x10]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v31.4h, v2.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v24.4h, v2.4h\n"
-    "smlal2 v8.4s, v24.8h, v2.8h\n"
-    "smlal v10.4s, v27.4h, v2.4h\n"
-    "smlal2 v9.4s, v27.8h, v2.8h\n"
-    "ldr d2, [x3, #0x60]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v20.4s, v31.8h, v3.8h\n"
-    "ldr d31, [x13, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v3.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "smlal v11.4s, v27.4h, v3.4h\n"
-    "smlal2 v8.4s, v27.8h, v3.8h\n"
-    "smlal v10.4s, v23.4h, v3.4h\n"
-    "smlal2 v9.4s, v23.8h, v3.8h\n"
-    "ldr d3, [x3, #0x68]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d26, [x14, x10]\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v11.4s, v23.4h, v4.4h\n"
-    "smlal2 v8.4s, v23.8h, v4.8h\n"
-    "smlal v10.4s, v28.4h, v4.4h\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "ldr d4, [x3, #0x70]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v20.4s, v22.8h, v0.8h\n"
-    "ldr d22, [x0, x10]\n"
-    "ssubl v22.8h, v22.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "smlal v11.4s, v31.4h, v0.4h\n"
-    "smlal2 v8.4s, v31.8h, v0.8h\n"
-    "smlal v10.4s, v30.4h, v0.4h\n"
-    "smlal2 v9.4s, v30.8h, v0.8h\n"
-    "ldr d0, [x3, #0x78]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v20.4s, v25.8h, v1.8h\n"
-    "ldr d25, [x11, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "smlal v11.4s, v30.4h, v1.4h\n"
-    "smlal2 v8.4s, v30.8h, v1.8h\n"
-    "smlal v10.4s, v26.4h, v1.4h\n"
-    "smlal2 v9.4s, v26.8h, v1.8h\n"
-    "ldr d1, [x3, #0x80]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v20.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x24, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "smlal v11.4s, v26.4h, v2.4h\n"
-    "smlal2 v8.4s, v26.8h, v2.8h\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "ldr d2, [x3, #0x88]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x15, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "ldr d3, [x3, #0x90]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v20.4s, v23.8h, v4.8h\n"
-    "ldr d23, [x9, x10]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "ldr d28, [x12, x10]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v22.4h, v4.4h\n"
-    "smlal2 v9.4s, v22.8h, v4.8h\n"
-    "ldr d4, [x3, #0x98]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x27, x10]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "smlal v11.4s, v27.4h, v0.4h\n"
-    "smlal2 v8.4s, v27.8h, v0.8h\n"
-    "smlal v10.4s, v23.4h, v0.4h\n"
-    "smlal2 v9.4s, v23.8h, v0.8h\n"
-    "ldr d0, [x3, #0xa0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "smlal v11.4s, v23.4h, v1.4h\n"
-    "smlal2 v8.4s, v23.8h, v1.8h\n"
-    "smlal v10.4s, v31.4h, v1.4h\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "ldr d1, [x3, #0xa8]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v20.4s, v26.8h, v2.8h\n"
-    "ldr d26, [x7, x10]\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "smlal2 v8.4s, v31.8h, v2.8h\n"
-    "smlal v10.4s, v30.4h, v2.4h\n"
-    "smlal2 v9.4s, v30.8h, v2.8h\n"
-    "ldr d2, [x3, #0xb0]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
+    "ldr d27, [x19, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x23, #0x30]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x26, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v30.4h, v3.4h\n"
-    "smlal2 v8.4s, v30.8h, v3.8h\n"
-    "smlal v10.4s, v28.4h, v3.4h\n"
-    "smlal2 v9.4s, v28.8h, v3.8h\n"
-    "ldr d3, [x3, #0xb8]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v31.8h, v2.8h\n"
+    "ldr d25, [x27, x24]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x23, #0x38]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x23, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "smlal v11.4s, v28.4h, v4.4h\n"
-    "smlal2 v8.4s, v28.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0xc0]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v0.4h\n"
-    "smlal2 v20.4s, v27.8h, v0.8h\n"
-    "ldr d27, [x22, x10]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v0.4h\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v8.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x20, x10]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v0.4h\n"
-    "smlal2 v9.4s, v24.8h, v0.8h\n"
-    "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v20.4s, v23.8h, v1.8h\n"
-    "smlal v18.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v8.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x13, x10]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v1.4h\n"
-    "smlal2 v9.4s, v27.8h, v1.8h\n"
-    "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v20.4s, v31.8h, v2.8h\n"
-    "smlal v18.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "smlal v11.4s, v27.4h, v2.4h\n"
-    "smlal2 v8.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x10]\n"
-    "add x10, x10, #0x8\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v20.4s, v30.8h, v3.8h\n"
-    "smlal v18.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v20.4s, v28.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v30.8h, v3.8h\n"
+    "ldr d24, [x5, x24]\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x23, #0x40]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x11, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
+    "smlal v15.4s, v29.4h, v0.4h\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
     "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v27.4h, v4.4h\n"
-    "smlal2 v9.4s, v27.8h, v4.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v17.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v6.4s\n"
-    "sqrdmulh v5.4s, v5.4s, v17.4s\n"
-    "and v1.16b, v15.16b, v21.16b\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "and v29.16b, v20.16b, v14.16b\n"
-    "and v3.16b, v18.16b, v21.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "and v2.16b, v5.16b, v14.16b\n"
-    "sqrdmulh v11.4s, v11.4s, v6.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v8.4s, v8.4s, v17.4s\n"
+    "ldr d4, [x23, #0x48]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "ldr q12, [x10, #0x0]\n"
+    "ldr q19, [x1, #0x0]\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "smlal v15.4s, v28.4h, v1.4h\n"
+    "ldr q20, [x10, #0x10]\n"
+    "ldr q29, [x1, #0x10]\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x26, x24]\n"
+    "ldr d0, [x23, #0x50]\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "tst x0, #0x7\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "smlal v15.4s, v23.4h, v2.4h\n"
+    "add x10, x10, #0x20\n"
+    "add x1, x1, #0x20\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x12, x24]\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x23, #0x58]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "smlal v15.4s, v31.4h, v3.4h\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x14, x24]\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x23, #0x60]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "ldr x14, [x20, #0x110]\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "smlal v15.4s, v30.4h, v4.4h\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x15, x24]\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x23, #0x68]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x21, x24]\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "smlal v15.4s, v22.4h, v0.4h\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x23, #0x70]\n"
+    "ldr d22, [x9, x24]\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "ssubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v25.4h, v1.4h\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x2, x24]\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x23, #0x78]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x13, x24]\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x23, #0x80]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "smlal v15.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x23, #0x88]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal v15.4s, v23.4h, v4.4h\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x28, x24]\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x23, #0x90]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x11, x24]\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal v15.4s, v31.4h, v0.4h\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x6, x24]\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x23, #0x98]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x24]\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v2.4h\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x17, x24]\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v3.4h\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x5, x24]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "smlal v15.4s, v24.4h, v4.4h\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x25, x24]\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "smlal v15.4s, v27.4h, v0.4h\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x26, x24]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x12, x24]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
+    "smlal2 v5.4s, v24.8h, v0.8h\n"
+    "smlal v15.4s, v23.4h, v1.4h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x14, x24]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
+    "smlal2 v5.4s, v27.8h, v1.8h\n"
+    "smlal v15.4s, v31.4h, v2.4h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x24]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x24, x24, #0x8\n"
+    "smlal v15.4s, v30.4h, v3.4h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "smlal v15.4s, v28.4h, v4.4h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
     "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v1.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v6.4s\n"
-    "and v0.16b, v11.16b, v21.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v21.4s\n"
-    "sqadd v20.4s, v20.4s, v29.4s\n"
-    "sqadd v18.4s, v18.4s, v3.4s\n"
-    "sqadd v5.4s, v5.4s, v2.4s\n"
-    "and v27.16b, v8.16b, v14.16b\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v21.4s\n"
-    "srshl v5.4s, v5.4s, v14.4s\n"
-    "smin v15.4s, v15.4s, v12.4s\n"
-    "add v20.4s, v20.4s, v19.4s\n"
-    "add v18.4s, v18.4s, v19.4s\n"
-    "smax v15.4s, v15.4s, v16.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smin v18.4s, v18.4s, v12.4s\n"
-    "add v5.4s, v5.4s, v19.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "smin v5.4s, v5.4s, v12.4s\n"
-    "uzp1 v15.16b, v15.16b, v20.16b\n"
-    "sqadd v11.4s, v11.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x17, x1]\n"
-    "smax v5.4s, v5.4s, v16.4s\n"
-    "sqadd v8.4s, v8.4s, v27.4s\n"
-    "srshl v11.4s, v11.4s, v21.4s\n"
-    "and v30.16b, v10.16b, v21.16b\n"
-    "sshr v30.4s, v30.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v5.16b\n"
-    "add v11.4s, v11.4s, v19.4s\n"
-    "srshl v8.4s, v8.4s, v14.4s\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "str d18, [x16, x1]\n"
-    "smin v11.4s, v11.4s, v12.4s\n"
-    "sqrdmulh v9.4s, v9.4s, v17.4s\n"
-    "add v8.4s, v8.4s, v19.4s\n"
-    "sqadd v10.4s, v10.4s, v30.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "smin v8.4s, v8.4s, v12.4s\n"
-    "and v6.16b, v9.16b, v14.16b\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "smax v8.4s, v8.4s, v16.4s\n"
-    "srshl v10.4s, v10.4s, v21.4s\n"
-    "uzp1 v11.16b, v11.16b, v8.16b\n"
-    "add v10.4s, v10.4s, v19.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x6, x1]\n"
-    "smin v10.4s, v10.4s, v12.4s\n"
-    "sqadd v9.4s, v9.4s, v6.4s\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "srshl v9.4s, v9.4s, v14.4s\n"
-    "add v9.4s, v9.4s, v19.4s\n"
-    "smin v9.4s, v9.4s, v12.4s\n"
-    "smax v9.4s, v9.4s, v16.4s\n"
-    "uzp1 v10.16b, v10.16b, v9.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
+    "str d15, [x16, x22]\n"
     "uzp1 v10.16b, v10.16b, v10.16b\n"
-    "str d10, [x8, x1]\n"
-    "add x1, x1, #0x8\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "str d17, [x8, x22]\n"
+    "str d10, [x4, x22]\n"
+    "str d6, [x7, x22]\n"
+    "add x22, x22, #0x8\n"
     "beq 124f\n"
-    "add x3, x3, #0xc8\n"
+    "add x23, x23, #0xc8\n"
     "3:"  // Oddments
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "tbz x4, #2, 5f\n"
-    "ld1 { v15.4s }, [x12], #0x10\n"
-    "tbz x4, #1, 4f\n"
-    "ld1 { v20.d }[0], [x12], #0x8\n"
-    "tbz x4, #0, 7f\n"
-    "ld1 { v20.s }[2], [x12]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "tbz x0, #2, 5f\n"
+    "ld1 { v15.4s }, [x19], #0x10\n"
+    "tbz x0, #1, 4f\n"
+    "ld1 { v16.d }[0], [x19], #0x8\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v16.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
-    "ld1 { v20.s }[0], [x12]\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v16.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
-    "tbz x4, #1, 6f\n"
-    "ld1 { v15.d }[0], [x12], #0x8\n"
-    "tbz x4, #0, 7f\n"
-    "ld1 { v15.s }[2], [x12]\n"
+    "tbz x0, #1, 6f\n"
+    "ld1 { v15.d }[0], [x19], #0x8\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v15.s }[2], [x19]\n"
     "b 7f\n"
     "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
-    "ld1 { v15.s }[0], [x12]\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v15.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v18.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "mov v5.16b, v20.16b\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v11.16b, v15.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "mov v8.16b, v20.16b\n"
-    "ldr d3, [x3, #0x18]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "mov v17.16b, v15.16b\n"
+    "mov v8.16b, v16.16b\n"
+    "ldr d2, [x23, #0x10]\n"
+    "ldr d3, [x23, #0x18]\n"
     "mov v10.16b, v15.16b\n"
-    "ldr d4, [x3, #0x20]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "mov v9.16b, v20.16b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "add x28, x28, x10\n"
-    "add x27, x27, x10\n"
-    "add x26, x26, x10\n"
-    "add x13, x13, x10\n"
-    "add x24, x24, x10\n"
-    "add x23, x23, x10\n"
-    "add x22, x22, x10\n"
-    "add x21, x21, x10\n"
-    "add x20, x20, x10\n"
-    "add x0, x0, x10\n"
-    "tbz x4, #2, 9f\n"
+    "mov v7.16b, v16.16b\n"
+    "ldr d4, [x23, #0x20]\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "mov v6.16b, v15.16b\n"
+    "mov v5.16b, v16.16b\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldp x27, x21, [x20, #0x30]\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "add x28, x28, x24\n"
+    "add x6, x6, x24\n"
+    "add x26, x26, x24\n"
+    "add x25, x25, x24\n"
+    "add x5, x5, x24\n"
+    "add x2, x2, x24\n"
+    "add x27, x27, x24\n"
+    "add x21, x21, x24\n"
+    "add x12, x12, x24\n"
+    "add x19, x19, x24\n"
+    "tbz x0, #2, 9f\n"
     "ld1 { v31.s }[0], [x28], #0x4\n"
-    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "ld1 { v30.s }[0], [x6], #0x4\n"
     "ld1 { v29.s }[0], [x26], #0x4\n"
-    "ld1 { v28.s }[0], [x13], #0x4\n"
-    "ld1 { v27.s }[0], [x24], #0x4\n"
-    "ld1 { v23.s }[0], [x23], #0x4\n"
-    "ld1 { v25.s }[0], [x22], #0x4\n"
+    "ld1 { v28.s }[0], [x25], #0x4\n"
+    "ld1 { v27.s }[0], [x5], #0x4\n"
+    "ld1 { v23.s }[0], [x2], #0x4\n"
+    "ld1 { v25.s }[0], [x27], #0x4\n"
     "ld1 { v24.s }[0], [x21], #0x4\n"
-    "ld1 { v26.s }[0], [x20], #0x4\n"
-    "ld1 { v22.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 8f\n"
+    "ld1 { v26.s }[0], [x12], #0x4\n"
+    "ld1 { v22.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 8f\n"
     "ld1 { v31.h }[2], [x28], #0x2\n"
-    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "ld1 { v30.h }[2], [x6], #0x2\n"
     "ld1 { v29.h }[2], [x26], #0x2\n"
-    "ld1 { v28.h }[2], [x13], #0x2\n"
-    "ld1 { v27.h }[2], [x24], #0x2\n"
-    "ld1 { v23.h }[2], [x23], #0x2\n"
-    "ld1 { v25.h }[2], [x22], #0x2\n"
+    "ld1 { v28.h }[2], [x25], #0x2\n"
+    "ld1 { v27.h }[2], [x5], #0x2\n"
+    "ld1 { v23.h }[2], [x2], #0x2\n"
+    "ld1 { v25.h }[2], [x27], #0x2\n"
     "ld1 { v24.h }[2], [x21], #0x2\n"
-    "ld1 { v26.h }[2], [x20], #0x2\n"
-    "ld1 { v22.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "ld1 { v26.h }[2], [x12], #0x2\n"
+    "ld1 { v22.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[6], [x28]\n"
-    "ld1 { v30.b }[6], [x27]\n"
+    "ld1 { v30.b }[6], [x6]\n"
     "ld1 { v29.b }[6], [x26]\n"
-    "ld1 { v28.b }[6], [x13]\n"
-    "ld1 { v27.b }[6], [x24]\n"
-    "ld1 { v23.b }[6], [x23]\n"
-    "ld1 { v25.b }[6], [x22]\n"
+    "ld1 { v28.b }[6], [x25]\n"
+    "ld1 { v27.b }[6], [x5]\n"
+    "ld1 { v23.b }[6], [x2]\n"
+    "ld1 { v25.b }[6], [x27]\n"
     "ld1 { v24.b }[6], [x21]\n"
-    "ld1 { v26.b }[6], [x20]\n"
-    "ld1 { v22.b }[6], [x0]\n"
+    "ld1 { v26.b }[6], [x12]\n"
+    "ld1 { v22.b }[6], [x19]\n"
     "b 11f\n"
     "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[4], [x28]\n"
-    "ld1 { v30.b }[4], [x27]\n"
+    "ld1 { v30.b }[4], [x6]\n"
     "ld1 { v29.b }[4], [x26]\n"
-    "ld1 { v28.b }[4], [x13]\n"
-    "ld1 { v27.b }[4], [x24]\n"
-    "ld1 { v23.b }[4], [x23]\n"
-    "ld1 { v25.b }[4], [x22]\n"
+    "ld1 { v28.b }[4], [x25]\n"
+    "ld1 { v27.b }[4], [x5]\n"
+    "ld1 { v23.b }[4], [x2]\n"
+    "ld1 { v25.b }[4], [x27]\n"
     "ld1 { v24.b }[4], [x21]\n"
-    "ld1 { v26.b }[4], [x20]\n"
-    "ld1 { v22.b }[4], [x0]\n"
+    "ld1 { v26.b }[4], [x12]\n"
+    "ld1 { v22.b }[4], [x19]\n"
     "b 11f\n"
     "9:"  // Oddments: Initial loads: Bit 2: Unset
-    "tbz x4, #1, 10f\n"
+    "tbz x0, #1, 10f\n"
     "ld1 { v31.h }[0], [x28], #0x2\n"
-    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "ld1 { v30.h }[0], [x6], #0x2\n"
     "ld1 { v29.h }[0], [x26], #0x2\n"
-    "ld1 { v28.h }[0], [x13], #0x2\n"
-    "ld1 { v27.h }[0], [x24], #0x2\n"
-    "ld1 { v23.h }[0], [x23], #0x2\n"
-    "ld1 { v25.h }[0], [x22], #0x2\n"
+    "ld1 { v28.h }[0], [x25], #0x2\n"
+    "ld1 { v27.h }[0], [x5], #0x2\n"
+    "ld1 { v23.h }[0], [x2], #0x2\n"
+    "ld1 { v25.h }[0], [x27], #0x2\n"
     "ld1 { v24.h }[0], [x21], #0x2\n"
-    "ld1 { v26.h }[0], [x20], #0x2\n"
-    "ld1 { v22.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "ld1 { v26.h }[0], [x12], #0x2\n"
+    "ld1 { v22.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[2], [x28]\n"
-    "ld1 { v30.b }[2], [x27]\n"
+    "ld1 { v30.b }[2], [x6]\n"
     "ld1 { v29.b }[2], [x26]\n"
-    "ld1 { v28.b }[2], [x13]\n"
-    "ld1 { v27.b }[2], [x24]\n"
-    "ld1 { v23.b }[2], [x23]\n"
-    "ld1 { v25.b }[2], [x22]\n"
+    "ld1 { v28.b }[2], [x25]\n"
+    "ld1 { v27.b }[2], [x5]\n"
+    "ld1 { v23.b }[2], [x2]\n"
+    "ld1 { v25.b }[2], [x27]\n"
     "ld1 { v24.b }[2], [x21]\n"
-    "ld1 { v26.b }[2], [x20]\n"
-    "ld1 { v22.b }[2], [x0]\n"
+    "ld1 { v26.b }[2], [x12]\n"
+    "ld1 { v22.b }[2], [x19]\n"
     "b 11f\n"
     "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[0], [x28]\n"
-    "ld1 { v30.b }[0], [x27]\n"
+    "ld1 { v30.b }[0], [x6]\n"
     "ld1 { v29.b }[0], [x26]\n"
-    "ld1 { v28.b }[0], [x13]\n"
-    "ld1 { v27.b }[0], [x24]\n"
-    "ld1 { v23.b }[0], [x23]\n"
-    "ld1 { v25.b }[0], [x22]\n"
+    "ld1 { v28.b }[0], [x25]\n"
+    "ld1 { v27.b }[0], [x5]\n"
+    "ld1 { v23.b }[0], [x2]\n"
+    "ld1 { v25.b }[0], [x27]\n"
     "ld1 { v24.b }[0], [x21]\n"
-    "ld1 { v26.b }[0], [x20]\n"
-    "ld1 { v22.b }[0], [x0]\n"
+    "ld1 { v26.b }[0], [x12]\n"
+    "ld1 { v22.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "ldr x20, [x25, #0x50]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ssubl v29.8h, v29.8b, v7.8b\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v11.4s, v29.4h, v0.4h\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal2 v8.4s, v29.8h, v0.8h\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v28.4h, v0.4h\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "ssubl v22.8h, v22.8b, v7.8b\n"
+    "ldr x19, [x20, #0x50]\n"
+    "ssubl v29.8h, v29.8b, v9.8b\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "add x19, x19, x24\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v5.4s, v28.8h, v0.8h\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "add x20, x20, x10\n"
-    "smlal v18.4s, v27.4h, v1.4h\n"
-    "smlal2 v5.4s, v27.8h, v1.8h\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
-    "smlal2 v8.4s, v28.8h, v1.8h\n"
-    "smlal v10.4s, v23.4h, v1.4h\n"
-    "smlal2 v9.4s, v23.8h, v1.8h\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "ssubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v5.4s, v23.8h, v1.8h\n"
     "smlal v15.4s, v27.4h, v2.4h\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v23.4h, v2.4h\n"
-    "smlal2 v8.4s, v23.8h, v2.8h\n"
-    "tbz x4, #2, 13f\n"
-    "ld1 { v31.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 12f\n"
-    "ld1 { v31.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[6], [x20]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
+    "tbz x0, #2, 13f\n"
+    "ld1 { v31.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 12f\n"
+    "ld1 { v31.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[6], [x19]\n"
     "b 15f\n"
     "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[4], [x20]\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[4], [x19]\n"
     "b 15f\n"
     "13:"  // Oddments: Load (1, 3): Bit 2: Unset
-    "tbz x4, #1, 14f\n"
-    "ld1 { v31.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[2], [x20]\n"
+    "tbz x0, #1, 14f\n"
+    "ld1 { v31.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[2], [x19]\n"
     "b 15f\n"
     "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[0], [x20]\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[0], [x19]\n"
     "15:"  // Oddments: Load (1, 3): Bit 2: End
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "ldr x15, [x20, #0x58]\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v5.4s, v31.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "add x28, x28, x10\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v10.4s, v31.4h, v2.4h\n"
-    "smlal2 v9.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v3.4h\n"
-    "smlal2 v8.4s, v31.8h, v3.8h\n"
-    "tbz x4, #2, 17f\n"
-    "ld1 { v30.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 16f\n"
-    "ld1 { v30.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[6], [x28]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "add x15, x15, x24\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
+    "tbz x0, #2, 17f\n"
+    "ld1 { v30.s }[0], [x15], #0x4\n"
+    "tbz x0, #1, 16f\n"
+    "ld1 { v30.h }[2], [x15], #0x2\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[6], [x15]\n"
     "b 19f\n"
     "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[4], [x28]\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[4], [x15]\n"
     "b 19f\n"
     "17:"  // Oddments: Load (1, 4): Bit 2: Unset
-    "tbz x4, #1, 18f\n"
-    "ld1 { v30.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[2], [x28]\n"
+    "tbz x0, #1, 18f\n"
+    "ld1 { v30.h }[0], [x15], #0x2\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[2], [x15]\n"
     "b 19f\n"
     "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[0], [x28]\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[0], [x15]\n"
     "19:"  // Oddments: Load (1, 4): Bit 2: End
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "ldr x19, [x20, #0x60]\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v5.4s, v30.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "add x0, x0, x10\n"
-    "smlal v10.4s, v30.4h, v3.4h\n"
-    "smlal2 v9.4s, v30.8h, v3.8h\n"
-    "tbz x4, #2, 21f\n"
-    "ld1 { v27.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 20f\n"
-    "ld1 { v27.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[6], [x0]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "add x19, x19, x24\n"
+    "tbz x0, #2, 21f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 20f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 23f\n"
     "20:"  // Oddments: Load (0, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[4], [x0]\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 23f\n"
     "21:"  // Oddments: Load (0, 5): Bit 2: Unset
-    "tbz x4, #1, 22f\n"
-    "ld1 { v27.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[2], [x0]\n"
+    "tbz x0, #1, 22f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 23f\n"
     "22:"  // Oddments: Load (0, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[0], [x0]\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "23:"  // Oddments: Load (0, 5): Bit 2: End
-    "smlal v11.4s, v30.4h, v4.4h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v8.4s, v30.8h, v4.8h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "add x7, x7, x10\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "ldr d0, [x23, #0x28]\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "add x27, x27, x24\n"
     "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v20.4s, v29.8h, v0.8h\n"
-    "smlal v18.4s, v28.4h, v0.4h\n"
-    "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "smlal v11.4s, v22.4h, v0.4h\n"
-    "smlal2 v8.4s, v22.8h, v0.8h\n"
-    "tbz x4, #2, 25f\n"
-    "ld1 { v25.s }[0], [x7], #0x4\n"
-    "tbz x4, #1, 24f\n"
-    "ld1 { v25.h }[2], [x7], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[6], [x7]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "tbz x0, #2, 25f\n"
+    "ld1 { v25.s }[0], [x27], #0x4\n"
+    "tbz x0, #1, 24f\n"
+    "ld1 { v25.h }[2], [x27], #0x2\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[6], [x27]\n"
     "b 27f\n"
     "24:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[4], [x7]\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[4], [x27]\n"
     "b 27f\n"
     "25:"  // Oddments: Load (2, 1): Bit 2: Unset
-    "tbz x4, #1, 26f\n"
-    "ld1 { v25.h }[0], [x7], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[2], [x7]\n"
+    "tbz x0, #1, 26f\n"
+    "ld1 { v25.h }[0], [x27], #0x2\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[2], [x27]\n"
     "b 27f\n"
     "26:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[0], [x7]\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[0], [x27]\n"
     "27:"  // Oddments: Load (2, 1): Bit 2: End
-    "ldr d1, [x3, #0x30]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v25.4h, v0.4h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v25.8h, v0.8h\n"
-    "add x26, x26, x10\n"
+    "ldr d1, [x23, #0x30]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x5, [x20, #0x70]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "add x5, x5, x24\n"
     "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "smlal v18.4s, v23.4h, v1.4h\n"
-    "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "smlal v11.4s, v25.4h, v1.4h\n"
-    "smlal2 v8.4s, v25.8h, v1.8h\n"
-    "tbz x4, #2, 29f\n"
-    "ld1 { v24.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 28f\n"
-    "ld1 { v24.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[6], [x26]\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "tbz x0, #2, 29f\n"
+    "ld1 { v24.s }[0], [x5], #0x4\n"
+    "tbz x0, #1, 28f\n"
+    "ld1 { v24.h }[2], [x5], #0x2\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[6], [x5]\n"
     "b 31f\n"
     "28:"  // Oddments: Load (2, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[4], [x26]\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[4], [x5]\n"
     "b 31f\n"
     "29:"  // Oddments: Load (2, 2): Bit 2: Unset
-    "tbz x4, #1, 30f\n"
-    "ld1 { v24.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[2], [x26]\n"
+    "tbz x0, #1, 30f\n"
+    "ld1 { v24.h }[0], [x5], #0x2\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[2], [x5]\n"
     "b 31f\n"
     "30:"  // Oddments: Load (2, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[0], [x26]\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[0], [x5]\n"
     "31:"  // Oddments: Load (2, 2): Bit 2: End
-    "ldr d2, [x3, #0x38]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v1.4h\n"
-    "ldr x23, [x25, #0x78]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v1.8h\n"
-    "add x23, x23, x10\n"
+    "ldr d2, [x23, #0x38]\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "add x11, x11, x24\n"
     "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v20.4s, v23.8h, v2.8h\n"
-    "smlal v18.4s, v31.4h, v2.4h\n"
-    "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v24.4h, v2.4h\n"
-    "smlal2 v8.4s, v24.8h, v2.8h\n"
-    "tbz x4, #2, 33f\n"
-    "ld1 { v27.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 32f\n"
-    "ld1 { v27.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[6], [x23]\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "tbz x0, #2, 33f\n"
+    "ld1 { v27.s }[0], [x11], #0x4\n"
+    "tbz x0, #1, 32f\n"
+    "ld1 { v27.h }[2], [x11], #0x2\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[6], [x11]\n"
     "b 35f\n"
     "32:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[4], [x23]\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[4], [x11]\n"
     "b 35f\n"
     "33:"  // Oddments: Load (2, 3): Bit 2: Unset
-    "tbz x4, #1, 34f\n"
-    "ld1 { v27.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[2], [x23]\n"
+    "tbz x0, #1, 34f\n"
+    "ld1 { v27.h }[0], [x11], #0x2\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[2], [x11]\n"
     "b 35f\n"
     "34:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[0], [x23]\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[0], [x11]\n"
     "35:"  // Oddments: Load (2, 3): Bit 2: End
-    "ldr d3, [x3, #0x40]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v2.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v27.8h, v2.8h\n"
-    "add x20, x20, x10\n"
+    "ldr d3, [x23, #0x40]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x12, [x20, #0x80]\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "add x12, x12, x24\n"
     "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v20.4s, v31.8h, v3.8h\n"
-    "smlal v18.4s, v30.4h, v3.4h\n"
-    "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "smlal v11.4s, v27.4h, v3.4h\n"
-    "smlal2 v8.4s, v27.8h, v3.8h\n"
-    "tbz x4, #2, 37f\n"
-    "ld1 { v23.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 36f\n"
-    "ld1 { v23.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[6], [x20]\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "tbz x0, #2, 37f\n"
+    "ld1 { v23.s }[0], [x12], #0x4\n"
+    "tbz x0, #1, 36f\n"
+    "ld1 { v23.h }[2], [x12], #0x2\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[6], [x12]\n"
     "b 39f\n"
     "36:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[4], [x20]\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[4], [x12]\n"
     "b 39f\n"
     "37:"  // Oddments: Load (2, 4): Bit 2: Unset
-    "tbz x4, #1, 38f\n"
-    "ld1 { v23.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[2], [x20]\n"
+    "tbz x0, #1, 38f\n"
+    "ld1 { v23.h }[0], [x12], #0x2\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[2], [x12]\n"
     "b 39f\n"
     "38:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[0], [x20]\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[0], [x12]\n"
     "39:"  // Oddments: Load (2, 4): Bit 2: End
-    "ldr d4, [x3, #0x48]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v10.4s, v23.4h, v3.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v23.8h, v3.8h\n"
-    "add x22, x22, x10\n"
+    "ldr d4, [x23, #0x48]\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "add x26, x26, x24\n"
     "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v23.4h, v4.4h\n"
-    "smlal2 v8.4s, v23.8h, v4.8h\n"
-    "tbz x4, #2, 41f\n"
-    "ld1 { v28.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 40f\n"
-    "ld1 { v28.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[6], [x22]\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "tbz x0, #2, 41f\n"
+    "ld1 { v28.s }[0], [x26], #0x4\n"
+    "tbz x0, #1, 40f\n"
+    "ld1 { v28.h }[2], [x26], #0x2\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[6], [x26]\n"
     "b 43f\n"
     "40:"  // Oddments: Load (2, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[4], [x22]\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[4], [x26]\n"
     "b 43f\n"
     "41:"  // Oddments: Load (2, 5): Bit 2: Unset
-    "tbz x4, #1, 42f\n"
-    "ld1 { v28.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[2], [x22]\n"
+    "tbz x0, #1, 42f\n"
+    "ld1 { v28.h }[0], [x26], #0x2\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[2], [x26]\n"
     "b 43f\n"
     "42:"  // Oddments: Load (2, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[0], [x22]\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[0], [x26]\n"
     "43:"  // Oddments: Load (2, 5): Bit 2: End
-    "ldr d0, [x3, #0x50]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v10.4s, v28.4h, v4.4h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "add x13, x13, x10\n"
+    "ldr d0, [x23, #0x50]\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x14, [x20, #0x90]\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "add x14, x14, x24\n"
     "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v20.4s, v22.8h, v0.8h\n"
-    "smlal v18.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "tbz x4, #2, 45f\n"
-    "ld1 { v31.s }[0], [x13], #0x4\n"
-    "tbz x4, #1, 44f\n"
-    "ld1 { v31.h }[2], [x13], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[6], [x13]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "tbz x0, #2, 45f\n"
+    "ld1 { v31.s }[0], [x14], #0x4\n"
+    "tbz x0, #1, 44f\n"
+    "ld1 { v31.h }[2], [x14], #0x2\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[6], [x14]\n"
     "b 47f\n"
     "44:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[4], [x13]\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[4], [x14]\n"
     "b 47f\n"
     "45:"  // Oddments: Load (3, 0): Bit 2: Unset
-    "tbz x4, #1, 46f\n"
-    "ld1 { v31.h }[0], [x13], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[2], [x13]\n"
+    "tbz x0, #1, 46f\n"
+    "ld1 { v31.h }[0], [x14], #0x2\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[2], [x14]\n"
     "b 47f\n"
     "46:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[0], [x13]\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[0], [x14]\n"
     "47:"  // Oddments: Load (3, 0): Bit 2: End
-    "ldr x21, [x25, #0x98]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v11.4s, v31.4h, v0.4h\n"
-    "smlal2 v8.4s, v31.8h, v0.8h\n"
-    "add x21, x21, x10\n"
-    "tbz x4, #2, 49f\n"
-    "ld1 { v30.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 48f\n"
-    "ld1 { v30.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[6], [x21]\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "add x15, x15, x24\n"
+    "tbz x0, #2, 49f\n"
+    "ld1 { v30.s }[0], [x15], #0x4\n"
+    "tbz x0, #1, 48f\n"
+    "ld1 { v30.h }[2], [x15], #0x2\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[6], [x15]\n"
     "b 51f\n"
     "48:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[4], [x21]\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[4], [x15]\n"
     "b 51f\n"
     "49:"  // Oddments: Load (3, 1): Bit 2: Unset
-    "tbz x4, #1, 50f\n"
-    "ld1 { v30.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[2], [x21]\n"
+    "tbz x0, #1, 50f\n"
+    "ld1 { v30.h }[0], [x15], #0x2\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[2], [x15]\n"
     "b 51f\n"
     "50:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[0], [x21]\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[0], [x15]\n"
     "51:"  // Oddments: Load (3, 1): Bit 2: End
-    "ldr d1, [x3, #0x58]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v10.4s, v30.4h, v0.4h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v30.8h, v0.8h\n"
-    "add x14, x14, x10\n"
+    "ldr d1, [x23, #0x58]\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "add x21, x21, x24\n"
     "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v20.4s, v25.8h, v1.8h\n"
-    "smlal v18.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "smlal v11.4s, v30.4h, v1.4h\n"
-    "smlal2 v8.4s, v30.8h, v1.8h\n"
-    "tbz x4, #2, 53f\n"
-    "ld1 { v26.s }[0], [x14], #0x4\n"
-    "tbz x4, #1, 52f\n"
-    "ld1 { v26.h }[2], [x14], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[6], [x14]\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "tbz x0, #2, 53f\n"
+    "ld1 { v26.s }[0], [x21], #0x4\n"
+    "tbz x0, #1, 52f\n"
+    "ld1 { v26.h }[2], [x21], #0x2\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[6], [x21]\n"
     "b 55f\n"
     "52:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[4], [x14]\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[4], [x21]\n"
     "b 55f\n"
     "53:"  // Oddments: Load (3, 2): Bit 2: Unset
-    "tbz x4, #1, 54f\n"
-    "ld1 { v26.h }[0], [x14], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[2], [x14]\n"
+    "tbz x0, #1, 54f\n"
+    "ld1 { v26.h }[0], [x21], #0x2\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[2], [x21]\n"
     "b 55f\n"
     "54:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[0], [x14]\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[0], [x21]\n"
     "55:"  // Oddments: Load (3, 2): Bit 2: End
-    "ldr d2, [x3, #0x60]\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v10.4s, v26.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v26.8h, v1.8h\n"
-    "add x11, x11, x10\n"
+    "ldr d2, [x23, #0x60]\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "add x2, x2, x24\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v20.4s, v24.8h, v2.8h\n"
-    "smlal v18.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "smlal v11.4s, v26.4h, v2.4h\n"
-    "smlal2 v8.4s, v26.8h, v2.8h\n"
-    "tbz x4, #2, 57f\n"
-    "ld1 { v25.s }[0], [x11], #0x4\n"
-    "tbz x4, #1, 56f\n"
-    "ld1 { v25.h }[2], [x11], #0x2\n"
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[6], [x11]\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "tbz x0, #2, 57f\n"
+    "ld1 { v25.s }[0], [x2], #0x4\n"
+    "tbz x0, #1, 56f\n"
+    "ld1 { v25.h }[2], [x2], #0x2\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[6], [x2]\n"
     "b 59f\n"
     "56:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[4], [x11]\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[4], [x2]\n"
     "b 59f\n"
     "57:"  // Oddments: Load (3, 3): Bit 2: Unset
-    "tbz x4, #1, 58f\n"
-    "ld1 { v25.h }[0], [x11], #0x2\n"
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[2], [x11]\n"
+    "tbz x0, #1, 58f\n"
+    "ld1 { v25.h }[0], [x2], #0x2\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[2], [x2]\n"
     "b 59f\n"
     "58:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[0], [x11]\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[0], [x2]\n"
     "59:"  // Oddments: Load (3, 3): Bit 2: End
-    "ldr d3, [x3, #0x68]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "add x24, x24, x10\n"
+    "ldr d3, [x23, #0x68]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x13, x13, x24\n"
     "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "smlal v18.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "tbz x4, #2, 61f\n"
-    "ld1 { v24.s }[0], [x24], #0x4\n"
-    "tbz x4, #1, 60f\n"
-    "ld1 { v24.h }[2], [x24], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[6], [x24]\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "tbz x0, #2, 61f\n"
+    "ld1 { v24.s }[0], [x13], #0x4\n"
+    "tbz x0, #1, 60f\n"
+    "ld1 { v24.h }[2], [x13], #0x2\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[6], [x13]\n"
     "b 63f\n"
     "60:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[4], [x24]\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[4], [x13]\n"
     "b 63f\n"
     "61:"  // Oddments: Load (3, 4): Bit 2: Unset
-    "tbz x4, #1, 62f\n"
-    "ld1 { v24.h }[0], [x24], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[2], [x24]\n"
+    "tbz x0, #1, 62f\n"
+    "ld1 { v24.h }[0], [x13], #0x2\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[2], [x13]\n"
     "b 63f\n"
     "62:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[0], [x24]\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[0], [x13]\n"
     "63:"  // Oddments: Load (3, 4): Bit 2: End
-    "ldr d4, [x3, #0x70]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "add x0, x0, x10\n"
+    "ldr d4, [x23, #0x70]\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "add x9, x9, x24\n"
     "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v20.4s, v23.8h, v4.8h\n"
-    "smlal v18.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "tbz x4, #2, 65f\n"
-    "ld1 { v22.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 64f\n"
-    "ld1 { v22.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[6], [x0]\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "tbz x0, #2, 65f\n"
+    "ld1 { v22.s }[0], [x9], #0x4\n"
+    "tbz x0, #1, 64f\n"
+    "ld1 { v22.h }[2], [x9], #0x2\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[6], [x9]\n"
     "b 67f\n"
     "64:"  // Oddments: Load (3, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[4], [x0]\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[4], [x9]\n"
     "b 67f\n"
     "65:"  // Oddments: Load (3, 5): Bit 2: Unset
-    "tbz x4, #1, 66f\n"
-    "ld1 { v22.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[2], [x0]\n"
+    "tbz x0, #1, 66f\n"
+    "ld1 { v22.h }[0], [x9], #0x2\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[2], [x9]\n"
     "b 67f\n"
     "66:"  // Oddments: Load (3, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[0], [x0]\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[0], [x9]\n"
     "67:"  // Oddments: Load (3, 5): Bit 2: End
-    "ldr d0, [x3, #0x78]\n"
-    "ssubl v22.8h, v22.8b, v7.8b\n"
-    "smlal v10.4s, v22.4h, v4.4h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v22.8h, v4.8h\n"
-    "add x15, x15, x10\n"
+    "ldr d0, [x23, #0x78]\n"
+    "ssubl v22.8h, v22.8b, v9.8b\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "add x19, x19, x24\n"
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "tbz x4, #2, 69f\n"
-    "ld1 { v27.s }[0], [x15], #0x4\n"
-    "tbz x4, #1, 68f\n"
-    "ld1 { v27.h }[2], [x15], #0x2\n"
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[6], [x15]\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "tbz x0, #2, 69f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 68f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 71f\n"
     "68:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[4], [x15]\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 71f\n"
     "69:"  // Oddments: Load (4, 0): Bit 2: Unset
-    "tbz x4, #1, 70f\n"
-    "ld1 { v27.h }[0], [x15], #0x2\n"
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[2], [x15]\n"
+    "tbz x0, #1, 70f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 71f\n"
     "70:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[0], [x15]\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "71:"  // Oddments: Load (4, 0): Bit 2: End
-    "ldr x9, [x25, #0xc8]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v11.4s, v27.4h, v0.4h\n"
-    "smlal2 v8.4s, v27.8h, v0.8h\n"
-    "add x9, x9, x10\n"
-    "tbz x4, #2, 73f\n"
-    "ld1 { v23.s }[0], [x9], #0x4\n"
-    "tbz x4, #1, 72f\n"
-    "ld1 { v23.h }[2], [x9], #0x2\n"
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[6], [x9]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "add x28, x28, x24\n"
+    "tbz x0, #2, 73f\n"
+    "ld1 { v23.s }[0], [x28], #0x4\n"
+    "tbz x0, #1, 72f\n"
+    "ld1 { v23.h }[2], [x28], #0x2\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[6], [x28]\n"
     "b 75f\n"
     "72:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[4], [x9]\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[4], [x28]\n"
     "b 75f\n"
     "73:"  // Oddments: Load (4, 1): Bit 2: Unset
-    "tbz x4, #1, 74f\n"
-    "ld1 { v23.h }[0], [x9], #0x2\n"
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[2], [x9]\n"
+    "tbz x0, #1, 74f\n"
+    "ld1 { v23.h }[0], [x28], #0x2\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[2], [x28]\n"
     "b 75f\n"
     "74:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[0], [x9]\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[0], [x28]\n"
     "75:"  // Oddments: Load (4, 1): Bit 2: End
-    "ldr d1, [x3, #0x80]\n"
-    "ssubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v10.4s, v23.4h, v0.4h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v23.8h, v0.8h\n"
-    "add x27, x27, x10\n"
+    "ldr d1, [x23, #0x80]\n"
+    "ssubl v23.8h, v23.8b, v9.8b\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "add x6, x6, x24\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "smlal v18.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "smlal v11.4s, v23.4h, v1.4h\n"
-    "smlal2 v8.4s, v23.8h, v1.8h\n"
-    "tbz x4, #2, 77f\n"
-    "ld1 { v31.s }[0], [x27], #0x4\n"
-    "tbz x4, #1, 76f\n"
-    "ld1 { v31.h }[2], [x27], #0x2\n"
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[6], [x27]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "tbz x0, #2, 77f\n"
+    "ld1 { v31.s }[0], [x6], #0x4\n"
+    "tbz x0, #1, 76f\n"
+    "ld1 { v31.h }[2], [x6], #0x2\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[6], [x6]\n"
     "b 79f\n"
     "76:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[4], [x27]\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[4], [x6]\n"
     "b 79f\n"
     "77:"  // Oddments: Load (4, 2): Bit 2: Unset
-    "tbz x4, #1, 78f\n"
-    "ld1 { v31.h }[0], [x27], #0x2\n"
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[2], [x27]\n"
+    "tbz x0, #1, 78f\n"
+    "ld1 { v31.h }[0], [x6], #0x2\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[2], [x6]\n"
     "b 79f\n"
     "78:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[0], [x27]\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[0], [x6]\n"
     "79:"  // Oddments: Load (4, 2): Bit 2: End
-    "ldr d2, [x3, #0x88]\n"
-    "ssubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v10.4s, v31.4h, v1.4h\n"
-    "ldr x28, [x25, #0xd8]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "add x28, x28, x10\n"
+    "ldr d2, [x23, #0x88]\n"
+    "ssubl v31.8h, v31.8b, v9.8b\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "add x27, x27, x24\n"
     "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v20.4s, v26.8h, v2.8h\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "smlal2 v8.4s, v31.8h, v2.8h\n"
-    "tbz x4, #2, 81f\n"
-    "ld1 { v30.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 80f\n"
-    "ld1 { v30.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[6], [x28]\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "tbz x0, #2, 81f\n"
+    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "tbz x0, #1, 80f\n"
+    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[6], [x27]\n"
     "b 83f\n"
     "80:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[4], [x28]\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[4], [x27]\n"
     "b 83f\n"
     "81:"  // Oddments: Load (4, 3): Bit 2: Unset
-    "tbz x4, #1, 82f\n"
-    "ld1 { v30.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[2], [x28]\n"
+    "tbz x0, #1, 82f\n"
+    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[2], [x27]\n"
     "b 83f\n"
     "82:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[0], [x28]\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[0], [x27]\n"
     "83:"  // Oddments: Load (4, 3): Bit 2: End
-    "ldr d3, [x3, #0x90]\n"
-    "ssubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v10.4s, v30.4h, v2.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v30.8h, v2.8h\n"
-    "add x12, x12, x10\n"
+    "ldr d3, [x23, #0x90]\n"
+    "ssubl v30.8h, v30.8b, v9.8b\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "add x11, x11, x24\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v30.4h, v3.4h\n"
-    "smlal2 v8.4s, v30.8h, v3.8h\n"
-    "tbz x4, #2, 85f\n"
-    "ld1 { v28.s }[0], [x12], #0x4\n"
-    "tbz x4, #1, 84f\n"
-    "ld1 { v28.h }[2], [x12], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[6], [x12]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "tbz x0, #2, 85f\n"
+    "ld1 { v28.s }[0], [x11], #0x4\n"
+    "tbz x0, #1, 84f\n"
+    "ld1 { v28.h }[2], [x11], #0x2\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[6], [x11]\n"
     "b 87f\n"
     "84:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[4], [x12]\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[4], [x11]\n"
     "b 87f\n"
     "85:"  // Oddments: Load (4, 4): Bit 2: Unset
-    "tbz x4, #1, 86f\n"
-    "ld1 { v28.h }[0], [x12], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[2], [x12]\n"
+    "tbz x0, #1, 86f\n"
+    "ld1 { v28.h }[0], [x11], #0x2\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[2], [x11]\n"
     "b 87f\n"
     "86:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[0], [x12]\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[0], [x11]\n"
     "87:"  // Oddments: Load (4, 4): Bit 2: End
-    "ldr d4, [x3, #0x98]\n"
-    "ssubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v10.4s, v28.4h, v3.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v28.8h, v3.8h\n"
-    "add x7, x7, x10\n"
+    "ldr d4, [x23, #0x98]\n"
+    "ssubl v28.8h, v28.8b, v9.8b\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "add x17, x17, x24\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "smlal v18.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "smlal v11.4s, v28.4h, v4.4h\n"
-    "smlal2 v8.4s, v28.8h, v4.8h\n"
-    "tbz x4, #2, 89f\n"
-    "ld1 { v26.s }[0], [x7], #0x4\n"
-    "tbz x4, #1, 88f\n"
-    "ld1 { v26.h }[2], [x7], #0x2\n"
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[6], [x7]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "tbz x0, #2, 89f\n"
+    "ld1 { v26.s }[0], [x17], #0x4\n"
+    "tbz x0, #1, 88f\n"
+    "ld1 { v26.h }[2], [x17], #0x2\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[6], [x17]\n"
     "b 91f\n"
     "88:"  // Oddments: Load (4, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[4], [x7]\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[4], [x17]\n"
     "b 91f\n"
     "89:"  // Oddments: Load (4, 5): Bit 2: Unset
-    "tbz x4, #1, 90f\n"
-    "ld1 { v26.h }[0], [x7], #0x2\n"
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[2], [x7]\n"
+    "tbz x0, #1, 90f\n"
+    "ld1 { v26.h }[0], [x17], #0x2\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[2], [x17]\n"
     "b 91f\n"
     "90:"  // Oddments: Load (4, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[0], [x7]\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[0], [x17]\n"
     "91:"  // Oddments: Load (4, 5): Bit 2: End
-    "ldr d0, [x3, #0xa0]\n"
-    "ssubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "ssubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "add x26, x26, x10\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "ssubl v26.8h, v26.8b, v9.8b\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "add x5, x5, x24\n"
     "smlal v15.4s, v27.4h, v0.4h\n"
-    "smlal2 v20.4s, v27.8h, v0.8h\n"
-    "smlal v18.4s, v23.4h, v0.4h\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "tbz x4, #2, 93f\n"
-    "ld1 { v25.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 92f\n"
-    "ld1 { v25.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[6], [x26]\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
+    "tbz x0, #2, 93f\n"
+    "ld1 { v25.s }[0], [x5], #0x4\n"
+    "tbz x0, #1, 92f\n"
+    "ld1 { v25.h }[2], [x5], #0x2\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[6], [x5]\n"
     "b 95f\n"
     "92:"  // Oddments: Load (5, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[4], [x26]\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[4], [x5]\n"
     "b 95f\n"
     "93:"  // Oddments: Load (5, 0): Bit 2: Unset
-    "tbz x4, #1, 94f\n"
-    "ld1 { v25.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[2], [x26]\n"
+    "tbz x0, #1, 94f\n"
+    "ld1 { v25.h }[0], [x5], #0x2\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[2], [x5]\n"
     "b 95f\n"
     "94:"  // Oddments: Load (5, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[0], [x26]\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[0], [x5]\n"
     "95:"  // Oddments: Load (5, 0): Bit 2: End
-    "ldr x23, [x25, #0xf8]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v8.4s, v25.8h, v0.8h\n"
-    "add x23, x23, x10\n"
-    "tbz x4, #2, 97f\n"
-    "ld1 { v24.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 96f\n"
-    "ld1 { v24.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[6], [x23]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "add x25, x25, x24\n"
+    "tbz x0, #2, 97f\n"
+    "ld1 { v24.s }[0], [x25], #0x4\n"
+    "tbz x0, #1, 96f\n"
+    "ld1 { v24.h }[2], [x25], #0x2\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[6], [x25]\n"
     "b 99f\n"
     "96:"  // Oddments: Load (5, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[4], [x23]\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[4], [x25]\n"
     "b 99f\n"
     "97:"  // Oddments: Load (5, 1): Bit 2: Unset
-    "tbz x4, #1, 98f\n"
-    "ld1 { v24.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[2], [x23]\n"
+    "tbz x0, #1, 98f\n"
+    "ld1 { v24.h }[0], [x25], #0x2\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[2], [x25]\n"
     "b 99f\n"
     "98:"  // Oddments: Load (5, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[0], [x23]\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[0], [x25]\n"
     "99:"  // Oddments: Load (5, 1): Bit 2: End
-    "ldr d1, [x3, #0xa8]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v0.4h\n"
-    "ldr x22, [x25, #0x100]\n"
-    "ssubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v0.8h\n"
-    "add x22, x22, x10\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v5.4s, v24.8h, v0.8h\n"
+    "add x26, x26, x24\n"
     "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v20.4s, v23.8h, v1.8h\n"
-    "smlal v18.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v8.4s, v24.8h, v1.8h\n"
-    "tbz x4, #2, 101f\n"
-    "ld1 { v27.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 100f\n"
-    "ld1 { v27.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[6], [x22]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "tbz x0, #2, 101f\n"
+    "ld1 { v27.s }[0], [x26], #0x4\n"
+    "tbz x0, #1, 100f\n"
+    "ld1 { v27.h }[2], [x26], #0x2\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[6], [x26]\n"
     "b 103f\n"
     "100:"  // Oddments: Load (5, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[4], [x22]\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[4], [x26]\n"
     "b 103f\n"
     "101:"  // Oddments: Load (5, 2): Bit 2: Unset
-    "tbz x4, #1, 102f\n"
-    "ld1 { v27.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[2], [x22]\n"
+    "tbz x0, #1, 102f\n"
+    "ld1 { v27.h }[0], [x26], #0x2\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[2], [x26]\n"
     "b 103f\n"
     "102:"  // Oddments: Load (5, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[0], [x22]\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[0], [x26]\n"
     "103:"  // Oddments: Load (5, 2): Bit 2: End
-    "ldr d2, [x3, #0xb0]\n"
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v1.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "ssubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v27.8h, v1.8h\n"
-    "add x20, x20, x10\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v5.4s, v27.8h, v1.8h\n"
+    "add x12, x12, x24\n"
     "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v20.4s, v31.8h, v2.8h\n"
-    "smlal v18.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "smlal v11.4s, v27.4h, v2.4h\n"
-    "smlal2 v8.4s, v27.8h, v2.8h\n"
-    "tbz x4, #2, 105f\n"
-    "ld1 { v25.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 104f\n"
-    "ld1 { v25.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[6], [x20]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "tbz x0, #2, 105f\n"
+    "ld1 { v25.s }[0], [x12], #0x4\n"
+    "tbz x0, #1, 104f\n"
+    "ld1 { v25.h }[2], [x12], #0x2\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[6], [x12]\n"
     "b 107f\n"
     "104:"  // Oddments: Load (5, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[4], [x20]\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[4], [x12]\n"
     "b 107f\n"
     "105:"  // Oddments: Load (5, 3): Bit 2: Unset
-    "tbz x4, #1, 106f\n"
-    "ld1 { v25.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[2], [x20]\n"
+    "tbz x0, #1, 106f\n"
+    "ld1 { v25.h }[0], [x12], #0x2\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[2], [x12]\n"
     "b 107f\n"
     "106:"  // Oddments: Load (5, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[0], [x20]\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[0], [x12]\n"
     "107:"  // Oddments: Load (5, 3): Bit 2: End
-    "ldr d3, [x3, #0xb8]\n"
-    "ssubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "ssubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "add x13, x13, x10\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "ssubl v25.8h, v25.8b, v9.8b\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x14, [x20, #0x110]\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x14, x14, x24\n"
     "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v20.4s, v30.8h, v3.8h\n"
-    "smlal v18.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "tbz x4, #2, 109f\n"
-    "ld1 { v24.s }[0], [x13], #0x4\n"
-    "tbz x4, #1, 108f\n"
-    "ld1 { v24.h }[2], [x13], #0x2\n"
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[6], [x13]\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "tbz x0, #2, 109f\n"
+    "ld1 { v24.s }[0], [x14], #0x4\n"
+    "tbz x0, #1, 108f\n"
+    "ld1 { v24.h }[2], [x14], #0x2\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[6], [x14]\n"
     "b 111f\n"
     "108:"  // Oddments: Load (5, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[4], [x13]\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[4], [x14]\n"
     "b 111f\n"
     "109:"  // Oddments: Load (5, 4): Bit 2: Unset
-    "tbz x4, #1, 110f\n"
-    "ld1 { v24.h }[0], [x13], #0x2\n"
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[2], [x13]\n"
+    "tbz x0, #1, 110f\n"
+    "ld1 { v24.h }[0], [x14], #0x2\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[2], [x14]\n"
     "b 111f\n"
     "110:"  // Oddments: Load (5, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[0], [x13]\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[0], [x14]\n"
     "111:"  // Oddments: Load (5, 4): Bit 2: End
-    "ldr d4, [x3, #0xc0]\n"
-    "ssubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "ssubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "add x21, x21, x10\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "ssubl v24.8h, v24.8b, v9.8b\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "add x21, x21, x24\n"
     "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v20.4s, v28.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "tbz x4, #2, 113f\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "tbz x0, #2, 113f\n"
     "ld1 { v27.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 112f\n"
+    "tbz x0, #1, 112f\n"
     "ld1 { v27.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[6], [x21]\n"
     "b 115f\n"
     "112:"  // Oddments: Load (5, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[4], [x21]\n"
     "b 115f\n"
     "113:"  // Oddments: Load (5, 5): Bit 2: Unset
-    "tbz x4, #1, 114f\n"
+    "tbz x0, #1, 114f\n"
     "ld1 { v27.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[2], [x21]\n"
     "b 115f\n"
     "114:"  // Oddments: Load (5, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[0], [x21]\n"
     "115:"  // Oddments: Load (5, 5): Bit 2: End
-    "ssubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v4.4h\n"
-    "smlal2 v9.4s, v27.8h, v4.8h\n"
-    "tbz x4, #2, 117f\n"
-    "ld1 { v6.4s }, [x2], #0x10\n"
-    "ld1 { v21.4s }, [x5], #0x10\n"
-    "tbz x4, #1, 116f\n"
-    "ld1 { v17.d }[0], [x2], #0x8\n"
-    "ld1 { v14.d }[0], [x5], #0x8\n"
-    "tbz x4, #0, 119f\n"
-    "ld1 { v17.s }[2], [x2]\n"
-    "ld1 { v14.s }[2], [x5]\n"
+    "ssubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "tbz x0, #2, 117f\n"
+    "ld1 { v12.4s }, [x10], #0x10\n"
+    "ld1 { v19.4s }, [x1], #0x10\n"
+    "tbz x0, #1, 116f\n"
+    "ld1 { v20.d }[0], [x10], #0x8\n"
+    "ld1 { v29.d }[0], [x1], #0x8\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v20.s }[2], [x10]\n"
+    "ld1 { v29.s }[2], [x1]\n"
     "b 119f\n"
     "116:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 119f\n"
-    "ld1 { v17.s }[0], [x2]\n"
-    "ld1 { v14.s }[0], [x5]\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v20.s }[0], [x10]\n"
+    "ld1 { v29.s }[0], [x1]\n"
     "b 119f\n"
     "117:"  // Oddments: Load requant params: Bit 2: Unset
-    "tbz x4, #1, 118f\n"
-    "ld1 { v6.d }[0], [x2], #0x8\n"
-    "ld1 { v21.d }[0], [x5], #0x8\n"
-    "tbz x4, #0, 119f\n"
-    "ld1 { v6.s }[2], [x2]\n"
-    "ld1 { v21.s }[2], [x5]\n"
+    "tbz x0, #1, 118f\n"
+    "ld1 { v12.d }[0], [x10], #0x8\n"
+    "ld1 { v19.d }[0], [x1], #0x8\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v12.s }[2], [x10]\n"
+    "ld1 { v19.s }[2], [x1]\n"
     "b 119f\n"
     "118:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 119f\n"
-    "ld1 { v6.s }[0], [x2]\n"
-    "ld1 { v21.s }[0], [x5]\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v12.s }[0], [x10]\n"
+    "ld1 { v19.s }[0], [x1]\n"
     "119:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "add x17, x17, x1\n"
-    "sqrdmulh v20.4s, v20.4s, v17.4s\n"
-    "add x16, x16, x1\n"
-    "sqrdmulh v18.4s, v18.4s, v6.4s\n"
-    "add x6, x6, x1\n"
-    "sqrdmulh v5.4s, v5.4s, v17.4s\n"
-    "add x8, x8, x1\n"
-    "sqrdmulh v11.4s, v11.4s, v6.4s\n"
-    "and v1.16b, v15.16b, v21.16b\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "and v29.16b, v20.16b, v14.16b\n"
-    "and v3.16b, v18.16b, v21.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "and v2.16b, v5.16b, v14.16b\n"
-    "and v0.16b, v11.16b, v21.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v8.4s, v8.4s, v17.4s\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "add x16, x16, x22\n"
+    "add x8, x8, x22\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "add x4, x4, x22\n"
+    "add x7, x7, x22\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
     "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v1.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v6.4s\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "sqrdmulh v9.4s, v9.4s, v17.4s\n"
-    "sqadd v20.4s, v20.4s, v29.4s\n"
-    "sqadd v18.4s, v18.4s, v3.4s\n"
-    "srshl v15.4s, v15.4s, v21.4s\n"
-    "sqadd v5.4s, v5.4s, v2.4s\n"
-    "srshl v20.4s, v20.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v21.4s\n"
-    "add v15.4s, v15.4s, v19.4s\n"
-    "srshl v5.4s, v5.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v19.4s\n"
-    "smin v15.4s, v15.4s, v12.4s\n"
-    "add v18.4s, v18.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v15.4s, v15.4s, v16.4s\n"
-    "smin v18.4s, v18.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "add v5.4s, v5.4s, v19.4s\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "uzp1 v15.16b, v15.16b, v20.16b\n"
-    "smin v5.4s, v5.4s, v12.4s\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "sqadd v11.4s, v11.4s, v0.4s\n"
-    "smax v5.4s, v5.4s, v16.4s\n"
-    "and v27.16b, v8.16b, v14.16b\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v5.16b\n"
-    "srshl v11.4s, v11.4s, v21.4s\n"
-    "and v30.16b, v10.16b, v21.16b\n"
-    "sshr v30.4s, v30.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "add v11.4s, v11.4s, v19.4s\n"
-    "sqadd v8.4s, v8.4s, v27.4s\n"
-    "and v6.16b, v9.16b, v14.16b\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "smin v11.4s, v11.4s, v12.4s\n"
-    "srshl v8.4s, v8.4s, v14.4s\n"
-    "sqadd v10.4s, v10.4s, v30.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "add v8.4s, v8.4s, v19.4s\n"
-    "srshl v10.4s, v10.4s, v21.4s\n"
-    "sqadd v9.4s, v9.4s, v6.4s\n"
-    "smin v8.4s, v8.4s, v12.4s\n"
-    "add v10.4s, v10.4s, v19.4s\n"
-    "srshl v9.4s, v9.4s, v14.4s\n"
-    "smax v8.4s, v8.4s, v16.4s\n"
-    "smin v10.4s, v10.4s, v12.4s\n"
-    "uzp1 v11.16b, v11.16b, v8.16b\n"
-    "add v9.4s, v9.4s, v19.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "smin v9.4s, v9.4s, v12.4s\n"
-    "smax v9.4s, v9.4s, v16.4s\n"
-    "uzp1 v10.16b, v10.16b, v9.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
     "uzp1 v10.16b, v10.16b, v10.16b\n"
-    "tbz x4, #2, 121f\n"
-    "st1 { v15.s }[0], [x17], #0x4\n"
-    "st1 { v18.s }[0], [x16], #0x4\n"
-    "st1 { v11.s }[0], [x6], #0x4\n"
-    "st1 { v10.s }[0], [x8], #0x4\n"
-    "tbz x4, #1, 120f\n"
-    "st1 { v15.h }[2], [x17], #0x2\n"
-    "st1 { v18.h }[2], [x16], #0x2\n"
-    "st1 { v11.h }[2], [x6], #0x2\n"
-    "st1 { v10.h }[2], [x8], #0x2\n"
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[6], [x17], #0x1\n"
-    "st1 { v18.b }[6], [x16], #0x1\n"
-    "st1 { v11.b }[6], [x6], #0x1\n"
-    "st1 { v10.b }[6], [x8], #0x1\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "tbz x0, #2, 121f\n"
+    "st1 { v15.s }[0], [x16], #0x4\n"
+    "st1 { v17.s }[0], [x8], #0x4\n"
+    "st1 { v10.s }[0], [x4], #0x4\n"
+    "st1 { v6.s }[0], [x7], #0x4\n"
+    "tbz x0, #1, 120f\n"
+    "st1 { v15.h }[2], [x16], #0x2\n"
+    "st1 { v17.h }[2], [x8], #0x2\n"
+    "st1 { v10.h }[2], [x4], #0x2\n"
+    "st1 { v6.h }[2], [x7], #0x2\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[6], [x16], #0x1\n"
+    "st1 { v17.b }[6], [x8], #0x1\n"
+    "st1 { v10.b }[6], [x4], #0x1\n"
+    "st1 { v6.b }[6], [x7], #0x1\n"
     "b 123f\n"
     "120:"  // Oddments: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[4], [x17], #0x1\n"
-    "st1 { v18.b }[4], [x16], #0x1\n"
-    "st1 { v11.b }[4], [x6], #0x1\n"
-    "st1 { v10.b }[4], [x8], #0x1\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[4], [x16], #0x1\n"
+    "st1 { v17.b }[4], [x8], #0x1\n"
+    "st1 { v10.b }[4], [x4], #0x1\n"
+    "st1 { v6.b }[4], [x7], #0x1\n"
     "b 123f\n"
     "121:"  // Oddments: Bit 2: Unset
-    "tbz x4, #1, 122f\n"
-    "st1 { v15.h }[0], [x17], #0x2\n"
-    "st1 { v18.h }[0], [x16], #0x2\n"
-    "st1 { v11.h }[0], [x6], #0x2\n"
-    "st1 { v10.h }[0], [x8], #0x2\n"
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[2], [x17], #0x1\n"
-    "st1 { v18.b }[2], [x16], #0x1\n"
-    "st1 { v11.b }[2], [x6], #0x1\n"
-    "st1 { v10.b }[2], [x8], #0x1\n"
+    "tbz x0, #1, 122f\n"
+    "st1 { v15.h }[0], [x16], #0x2\n"
+    "st1 { v17.h }[0], [x8], #0x2\n"
+    "st1 { v10.h }[0], [x4], #0x2\n"
+    "st1 { v6.h }[0], [x7], #0x2\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[2], [x16], #0x1\n"
+    "st1 { v17.b }[2], [x8], #0x1\n"
+    "st1 { v10.b }[2], [x4], #0x1\n"
+    "st1 { v6.b }[2], [x7], #0x1\n"
     "b 123f\n"
     "122:"  // Oddments: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[0], [x17], #0x1\n"
-    "st1 { v18.b }[0], [x16], #0x1\n"
-    "st1 { v11.b }[0], [x6], #0x1\n"
-    "st1 { v10.b }[0], [x8], #0x1\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[0], [x16], #0x1\n"
+    "st1 { v17.b }[0], [x8], #0x1\n"
+    "st1 { v10.b }[0], [x4], #0x1\n"
+    "st1 { v6.b }[0], [x7], #0x1\n"
     "123:"  // Oddments: Bit 2: End
-
     "124:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_generic_output9_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_generic_output9_mla_depthfirst.hpp
index 4e845cc..9b1f7c2 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_generic_output9_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_nhwc_generic_output9_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,28 +28,23 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_s8q_nhwc_generic_output9_mla_depthfirst_impl(const int8_t *const *const, int8_t *const *const, const void *, const arm_gemm::Requantize32&, const unsigned int, const unsigned int);
 
-struct a64_s8q_nhwc_generic_output9_mla_depthfirst
+class a64_s8q_nhwc_generic_output9_mla_depthfirst : public GenericDepthfirstKernelStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  KernelType kernel = a64_s8q_nhwc_generic_output9_mla_depthfirst_impl;
 
-  typedef void (*kern_type)(const int8_t *const *const, int8_t *const *const, const void *, const arm_gemm::Requantize32&, const unsigned int, const unsigned int);
+  public:
+  a64_s8q_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) : GenericDepthfirstKernelStrategy<int8_t, int8_t, int8_t, int32_t>(9, arm_gemm::VLType::None) {}
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int n_output_points = 9;
-
-  kern_type kernel = a64_s8q_nhwc_generic_output9_mla_depthfirst_impl;
-
-  a64_s8q_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) {}
+  KernelType get_kernel() const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
index b9fef4f..5ca3ccd 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,39 +28,34 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst
+struct a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst : DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
-
-  typedef void (*kern_type)(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
+  using Parent = DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>;
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 4;
+  a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *)
+  : Parent(2, 4, kernel_rows, kernel_cols, stride_rows, stride_cols)
+  {
+  }
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 9;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::None; }
 
-  kern_type kernel = a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;
-
-  a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
index 9a3eed4..0641229 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -26,6 +26,8 @@
 
 #include <cstdint>
 
+#if defined(__aarch64__)
+
 #pragma once
 
 namespace arm_conv {
@@ -33,34 +35,26 @@
 
 void a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst
+struct a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst : DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
-
-  typedef void (*kern_type)(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
+  using Parent = DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>;
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 4;
-  constexpr static unsigned int output_cols = 2;
+  a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *)
+  : Parent(4, 2, kernel_rows, kernel_cols, stride_rows, stride_cols)
+  {
+  }
 
-  constexpr static unsigned int input_rows = 8;
-  constexpr static unsigned int input_cols = 6;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::None; }
 
-  kern_type kernel = a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;
-
-  a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
index d0ae00d..3dad8d5 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,31 +28,25 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl(const int8_t *const *const, int8_t *const *const, const int8_t *, const int32_t *, const unsigned int, const unsigned int, const int32_t *, const int32_t *, const int32_t *, const arm_gemm::Requantize32&);
 
-struct a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst
+struct a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst : GenericDepthfirstMultiplierKernelStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
-
-  typedef void (*kern_type)(const int8_t *const *const, int8_t *const *const, const int8_t *, const int32_t *, const unsigned int, const unsigned int, const int32_t *, const int32_t *, const int32_t *, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int output_rows(void) { return 2; };
-  constexpr static unsigned int output_cols(void) { return 8; };
-
-  constexpr static unsigned int output_col_regs(void) { return 2; };
-
-  kern_type kernel = a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
-
-  a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *) {}
+  using Parent = GenericDepthfirstMultiplierKernelStrategy<int8_t, int8_t, int8_t, int32_t>;
+  a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *)
+  : Parent(2, 8, arm_gemm::VLType::None)
+  {
+  }
+  Parent::KernelType kernel = a64_s8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
index 0fde00b..22b6b65 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,39 +34,40 @@
 namespace arm_conv {
 namespace depthwise {
 
-void a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
+void a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32&, const int32_t *, const int32_t *, int8_t *const *);
 
-struct a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst
+class a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_3x3_dot::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_3x3_dot::get_packed_size;
+  Parent::KernelType kernel = a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleave_a64_s8q_3x3_dot::get_packed_size(args);
+  }
 
-  kern_type kernel = a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
-
-  a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) {}
+  void pack_parameters(
+    const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleave_a64_s8q_3x3_dot::pack_parameters(
+      args.input_channels, buffer, reinterpret_cast<const int32_t *>(biases),
+      reinterpret_cast<const int8_t *>(weights), qp, ld_weight_col, ld_weight_row
+    );
+  }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
index bdbda17..761c7ec 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,15 @@
 namespace arm_conv {
 namespace depthwise {
 
-void a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *const inptrs, int8_t *const *const outptrs, const void *params, const uint64_t n_channels, const arm_gemm::Requantize32& qp)
+void a64_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(
+  const unsigned int n_channels,
+  const int8_t *const *const inptrs,
+  const int8_t *params,
+  const int32_t *,  // Bias: unused here, expected to be already packed into `params`
+  const arm_gemm::Requantize32& qp,
+  const int32_t *, const int32_t *,  // Requantization parameters: unused here, expected to be already packed into `params`
+  int8_t *const *const outptrs
+)
 {
   __asm__ __volatile__(
     "ldp x15, x14, [%x[inptrs], #0x0]\n"
@@ -1173,7 +1181,7 @@
     "34:"  // End
     "add SP, SP, #0x80\n"
     : [params] "+&r" (params)
-    : [inptrs] "r" (inptrs), [n_channels] "r" (n_channels), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
+    : [inptrs] "r" (inptrs), [n_channels] "r" ((long unsigned int) n_channels), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
   );
 }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
index 05eddd1..00c8a3c 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,39 +34,40 @@
 namespace arm_conv {
 namespace depthwise {
 
-void a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const uint8_t *const *, uint8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
+void a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32&, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst
+class a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>
 {
-  typedef uint32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(const uint8_t *const *, uint8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int32_t *, const uint8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_u8q_3x3_dot::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_u8q_3x3_dot::get_packed_size;
+  Parent::KernelType kernel = a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  size_t get_storage_size(const DepthwiseArgs &args) const override
+  {
+    return interleave_a64_u8q_3x3_dot::get_packed_size(args);
+  }
 
-  kern_type kernel = a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
-
-  a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) {}
+  void pack_parameters(
+    const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleave_a64_u8q_3x3_dot::pack_parameters(
+      args.input_channels, buffer, reinterpret_cast<const int32_t *>(biases),
+      reinterpret_cast<const uint8_t *>(weights), qp, ld_weight_col, ld_weight_row
+    );
+  }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
index 22c584f..64b305c 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,15 @@
 namespace arm_conv {
 namespace depthwise {
 
-void a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const uint8_t *const *const inptrs, uint8_t *const *const outptrs, const void *params, const uint64_t n_channels, const arm_gemm::Requantize32& qp)
+void a64_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(
+  const unsigned int n_channels,
+  const uint8_t *const *const inptrs,
+  const uint8_t *params,
+  const int32_t *,  // Bias: unused here, expected to be already packed into `params`
+  const arm_gemm::Requantize32& qp,
+  const int32_t *, const int32_t *,  // Requantization parameters: unused here, expected to be already packed into `params`
+  uint8_t *const *const outptrs
+)
 {
   __asm__ __volatile__(
     "ldp x13, x12, [%x[inptrs], #0x0]\n"
@@ -1307,7 +1315,7 @@
     "34:"  // End
     "add SP, SP, #0x80\n"
     : [params] "+&r" (params)
-    : [inptrs] "r" (inptrs), [n_channels] "r" (n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
+    : [inptrs] "r" (inptrs), [n_channels] "r" ((long unsigned) n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
   );
 }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index 09ba75f..b55055f 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst
+class a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const uint8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_u8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_u8q_3x3_mla::get_packed_size;
-
-  kern_type kernel = a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
-
-  a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
index b62ebb1..453f9cf 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const uint8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const uint8_t *const *inptrs_raw,
-      const uint8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -91,513 +91,497 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
     "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x17, #0x0\n"
-    "ldr x16, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x24, x19, %[offsetof_Requantize32_a_offset]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x22, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x21, x19, %[offsetof_Requantize32_c_offset]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v22.16b }, [x24]\n"
+    "ld1r { v12.16b }, [x23]\n"
+    "lsr x16, x8, #0x3\n"
+    "ld1r { v14.8h }, [x21]\n"
+    "ld1r { v17.8h }, [x20]\n"
     "mov x15, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x14, %x[params], %[offsetof_Params_inptrs]\n"
+    "mov x14, #0x0\n"
+    "ld1r { v15.8h }, [x19]\n"
     "ldr x13, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x12, x8, #0x3\n"
+    "add x12, %x[params], %[offsetof_Params_inptrs]\n"
     "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x19, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v14.16b }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v9.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v15.4s }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v24.4s }, [x20]\n"
-    "ld1r { v12.4s }, [x19]\n"
-    "ldp x10, x9, [x21, #0x0]\n"
-    "ldp x28, x27, [x21, #0x10]\n"
-    "cbz x12, 3f\n"
-    "subs x12, x12, #0x1\n"
+    "ldp x10, x9, [x22, #0x0]\n"
+    "ldp x28, x27, [x22, #0x10]\n"
+    "cbz x16, 3f\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q13, [x19, #0x0]\n"
-    "mov v17.16b, v13.16b\n"
-    "ldr q19, [x19, #0x10]\n"
+    "subs x16, x16, #0x1\n"
+    "mov v19.16b, v13.16b\n"
+    "ldr q26, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v13.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v23.16b, v13.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "usubl v0.8h, v0.8b, v9.8b\n"
-    "mov v25.16b, v19.16b\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v21.16b, v19.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "usubl v1.8h, v1.8b, v9.8b\n"
-    "mov v20.16b, v19.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "ldr d4, [x16, #0x20]\n"
-    "usubl v2.8h, v2.8b, v9.8b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "usubl v3.8h, v3.8b, v9.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ldr d7, [x16, #0x38]\n"
-    "usubl v4.8h, v4.8b, v9.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "usubl v5.8h, v5.8b, v9.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "usubl v6.8h, v6.8b, v9.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "usubl v7.8h, v7.8b, v9.8b\n"
-    "usubl v8.8h, v8.8b, v9.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ldr d31, [x23, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
-    "ldr d30, [x22, x17]\n"
-    "ldr d29, [x21, x17]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
-    "ldr d28, [x20, x17]\n"
-    "ldr d27, [x19, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "usubl v27.8h, v27.8b, v14.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v11.16b, v26.16b\n"
+    "mov v18.16b, v13.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v24.16b, v26.16b\n"
+    "mov v9.16b, v13.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v23.16b, v26.16b\n"
+    "usubl v0.8h, v0.8b, v12.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v1.8h, v1.8b, v12.8b\n"
+    "usubl v2.8h, v2.8b, v12.8b\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "usubl v3.8h, v3.8b, v12.8b\n"
+    "usubl v4.8h, v4.8b, v12.8b\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "usubl v5.8h, v5.8b, v12.8b\n"
+    "usubl v6.8h, v6.8b, v12.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "usubl v7.8h, v7.8b, v12.8b\n"
+    "usubl v8.8h, v8.8b, v12.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "usubl v27.8h, v27.8b, v22.8b\n"
     "beq 2f\n"
     "1:"  // Loop
     "smlal v13.4s, v31.4h, v4.4h\n"
-    "ldr x21, [x14, #0x28]\n"
-    "add x16, x16, #0x48\n"
-    "smlal2 v19.4s, v31.8h, v4.8h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "subs x12, x12, #0x1\n"
-    "smlal v17.4s, v31.4h, v3.4h\n"
-    "ldr x26, [x14, #0x38]\n"
-    "smlal2 v25.4s, v31.8h, v3.8h\n"
-    "ldr x25, [x14, #0x40]\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "ldr x19, [x14, #0x48]\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "ldr x24, [x14, #0x50]\n"
-    "smlal v23.4s, v31.4h, v0.4h\n"
-    "ldr x23, [x14, #0x58]\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x21, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x12, #0x28]\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x12, #0x30]\n"
+    "ldr x25, [x12, #0x40]\n"
     "smlal v13.4s, v30.4h, v0.4h\n"
-    "ldr x22, [x14, #0x60]\n"
-    "smlal2 v19.4s, v30.8h, v0.8h\n"
-    "ldr d30, [x19, x17]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "ldr x21, [x14, #0x68]\n"
-    "smlal2 v25.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x12, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x12, #0x50]\n"
+    "ldr x23, [x12, #0x58]\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
     "smlal v13.4s, v28.4h, v5.4h\n"
-    "ldr x20, [x14, #0x70]\n"
-    "smlal2 v19.4s, v28.8h, v5.8h\n"
-    "ldr x19, [x14, #0x78]\n"
-    "smlal v17.4s, v28.4h, v4.4h\n"
-    "ldr q26, [x13, #0x0]\n"
-    "smlal2 v25.4s, v28.8h, v4.8h\n"
-    "ldr q10, [x11, #0x0]\n"
-    "smlal v16.4s, v28.4h, v2.4h\n"
-    "ldr q11, [x13, #0x10]\n"
-    "add x13, x13, #0x20\n"
-    "smlal2 v21.4s, v28.8h, v2.8h\n"
-    "ldr q18, [x11, #0x10]\n"
-    "add x11, x11, #0x20\n"
-    "smlal v23.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x26, x17]\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v16.4s, v31.4h, v6.4h\n"
-    "smlal2 v21.4s, v31.8h, v6.8h\n"
-    "ldr d31, [x25, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v5.8h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x12, #0x68]\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x12, #0x78]\n"
+    "ldr q21, [x13, #0x0]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
     "smlal v13.4s, v27.4h, v7.4h\n"
-    "smlal2 v19.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v27.4h, v6.4h\n"
-    "smlal2 v25.4s, v27.8h, v6.8h\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "smlal v23.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
+    "smlal2 v26.4s, v27.8h, v7.8h\n"
+    "ldr q25, [x11, #0x0]\n"
+    "ldr q10, [x13, #0x10]\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "ldr q16, [x11, #0x10]\n"
+    "add x17, x17, #0x48\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
+    "subs x16, x16, #0x1\n"
+    "add x13, x13, #0x20\n"
     "smlal v13.4s, v28.4h, v1.4h\n"
-    "smlal2 v19.4s, v28.8h, v1.8h\n"
-    "smlal v23.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "ldr d29, [x24, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v0.4h\n"
-    "smlal2 v25.4s, v28.8h, v0.8h\n"
-    "ldr d28, [x23, x17]\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "add x11, x11, #0x20\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
     "smlal v13.4s, v31.4h, v2.4h\n"
-    "smlal2 v19.4s, v31.8h, v2.8h\n"
-    "smlal v17.4s, v31.4h, v1.4h\n"
-    "smlal2 v25.4s, v31.8h, v1.8h\n"
-    "ldr d31, [x22, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v2.8h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
     "smlal v13.4s, v30.4h, v8.4h\n"
-    "smlal2 v19.4s, v30.8h, v8.8h\n"
-    "smlal v17.4s, v30.4h, v7.4h\n"
-    "smlal2 v25.4s, v30.8h, v7.8h\n"
-    "smlal v16.4s, v30.4h, v5.4h\n"
-    "smlal2 v21.4s, v30.8h, v5.8h\n"
-    "smlal v23.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x17]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v8.8h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
     "smlal v13.4s, v29.4h, v3.4h\n"
-    "smlal2 v19.4s, v29.8h, v3.8h\n"
-    "smlal v16.4s, v29.4h, v0.4h\n"
-    "smlal2 v21.4s, v29.8h, v0.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v5.4h\n"
-    "smlal2 v25.4s, v28.8h, v5.8h\n"
-    "smlal v23.4s, v28.4h, v2.4h\n"
-    "smlal2 v20.4s, v28.8h, v2.8h\n"
-    "ldr d28, [x19, x17]\n"
-    "add x17, x17, #0x8\n"
-    "smlal v13.4s, v31.4h, v6.4h\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "smlal2 v19.4s, v31.8h, v6.8h\n"
-    "smlal v16.4s, v31.4h, v3.4h\n"
-    "smlal2 v21.4s, v31.8h, v3.8h\n"
-    "smlal v17.4s, v30.4h, v8.4h\n"
-    "smlal2 v25.4s, v30.8h, v8.8h\n"
-    "smlal v23.4s, v30.4h, v5.4h\n"
-    "smlal2 v20.4s, v30.8h, v5.8h\n"
-    "smlal v16.4s, v29.4h, v7.4h\n"
-    "smlal2 v21.4s, v29.8h, v7.8h\n"
-    "smlal v23.4s, v29.4h, v6.4h\n"
-    "smlal2 v20.4s, v29.8h, v6.8h\n"
-    "smlal v16.4s, v28.4h, v8.4h\n"
-    "smlal2 v21.4s, v28.8h, v8.8h\n"
-    "smlal v23.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
-    "sqrdmulh v13.4s, v13.4s, v26.4s\n"
-    "sqrdmulh v19.4s, v19.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v26.4s\n"
-    "sqrdmulh v25.4s, v25.4s, v11.4s\n"
-    "and v22.16b, v13.16b, v10.16b\n"
-    "sshr v22.4s, v22.4s, #0x1f\n"
-    "and v28.16b, v19.16b, v18.16b\n"
-    "and v3.16b, v17.16b, v10.16b\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "and v6.16b, v25.16b, v18.16b\n"
-    "sqrdmulh v16.4s, v16.4s, v26.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v11.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sqadd v13.4s, v13.4s, v22.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v26.4s\n"
-    "and v0.16b, v16.16b, v10.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v13.4s, v13.4s, v10.4s\n"
-    "sqadd v19.4s, v19.4s, v28.4s\n"
-    "sqadd v17.4s, v17.4s, v3.4s\n"
-    "sqadd v25.4s, v25.4s, v6.4s\n"
-    "and v29.16b, v21.16b, v18.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "add v13.4s, v13.4s, v15.4s\n"
-    "srshl v19.4s, v19.4s, v18.4s\n"
-    "srshl v17.4s, v17.4s, v10.4s\n"
-    "srshl v25.4s, v25.4s, v18.4s\n"
-    "smin v13.4s, v13.4s, v12.4s\n"
-    "add v19.4s, v19.4s, v15.4s\n"
-    "add v17.4s, v17.4s, v15.4s\n"
-    "smax v13.4s, v13.4s, v24.4s\n"
-    "smin v19.4s, v19.4s, v12.4s\n"
-    "smin v17.4s, v17.4s, v12.4s\n"
-    "add v25.4s, v25.4s, v15.4s\n"
-    "smax v19.4s, v19.4s, v24.4s\n"
-    "smax v17.4s, v17.4s, v24.4s\n"
-    "smin v25.4s, v25.4s, v12.4s\n"
-    "uzp1 v13.16b, v13.16b, v19.16b\n"
-    "sqadd v16.4s, v16.4s, v0.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "str d13, [x10, x15]\n"
-    "smax v25.4s, v25.4s, v24.4s\n"
-    "sqadd v21.4s, v21.4s, v29.4s\n"
-    "srshl v16.4s, v16.4s, v10.4s\n"
-    "and v3.16b, v23.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v25.16b\n"
-    "add v16.4s, v16.4s, v15.4s\n"
-    "srshl v21.4s, v21.4s, v18.4s\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x9, x15]\n"
-    "smin v16.4s, v16.4s, v12.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v11.4s\n"
-    "add v21.4s, v21.4s, v15.4s\n"
-    "sqadd v23.4s, v23.4s, v3.4s\n"
-    "smax v16.4s, v16.4s, v24.4s\n"
-    "smin v21.4s, v21.4s, v12.4s\n"
-    "and v25.16b, v20.16b, v18.16b\n"
-    "sshr v25.4s, v25.4s, #0x1f\n"
-    "smax v21.4s, v21.4s, v24.4s\n"
-    "srshl v23.4s, v23.4s, v10.4s\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "add v23.4s, v23.4s, v15.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x28, x15]\n"
-    "smin v23.4s, v23.4s, v12.4s\n"
-    "sqadd v20.4s, v20.4s, v25.4s\n"
-    "smax v23.4s, v23.4s, v24.4s\n"
-    "srshl v20.4s, v20.4s, v18.4s\n"
-    "add v20.4s, v20.4s, v15.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v24.4s\n"
-    "uzp1 v23.16b, v23.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
-    "str d23, [x27, x15]\n"
+    "smlal2 v26.4s, v29.8h, v3.8h\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
     "add x15, x15, #0x8\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "str d13, [x10, x14]\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "str d19, [x9, x14]\n"
+    "str d18, [x28, x14]\n"
+    "str d9, [x27, x14]\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q13, [x19, #0x0]\n"
-    "mov v17.16b, v13.16b\n"
-    "ldr q19, [x19, #0x10]\n"
+    "add x14, x14, #0x8\n"
+    "ldr q26, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v13.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v23.16b, v13.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "usubl v0.8h, v0.8b, v9.8b\n"
-    "mov v25.16b, v19.16b\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v21.16b, v19.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "usubl v1.8h, v1.8b, v9.8b\n"
-    "mov v20.16b, v19.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "ldr d4, [x16, #0x20]\n"
-    "usubl v2.8h, v2.8b, v9.8b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "usubl v3.8h, v3.8b, v9.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ldr d7, [x16, #0x38]\n"
-    "usubl v4.8h, v4.8b, v9.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "usubl v5.8h, v5.8b, v9.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "usubl v6.8h, v6.8b, v9.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "usubl v7.8h, v7.8b, v9.8b\n"
-    "usubl v8.8h, v8.8b, v9.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ldr d31, [x23, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
-    "ldr d30, [x22, x17]\n"
-    "ldr d29, [x21, x17]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
-    "ldr d28, [x20, x17]\n"
-    "ldr d27, [x19, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "usubl v27.8h, v27.8b, v14.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v19.16b, v13.16b\n"
+    "mov v11.16b, v26.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v18.16b, v13.16b\n"
+    "mov v24.16b, v26.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v9.16b, v13.16b\n"
+    "mov v23.16b, v26.16b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v0.8h, v0.8b, v12.8b\n"
+    "usubl v1.8h, v1.8b, v12.8b\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "usubl v2.8h, v2.8b, v12.8b\n"
+    "usubl v3.8h, v3.8b, v12.8b\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "usubl v4.8h, v4.8b, v12.8b\n"
+    "usubl v5.8h, v5.8b, v12.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "usubl v6.8h, v6.8b, v12.8b\n"
+    "usubl v7.8h, v7.8b, v12.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v8.8h, v8.8b, v12.8b\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "usubl v27.8h, v27.8b, v22.8b\n"
     "bgt 1b\n"
     "2:"  // Tail
     "smlal v13.4s, v31.4h, v4.4h\n"
-    "ldr x21, [x14, #0x28]\n"
-    "tst x8, #0x7\n"
-    "smlal2 v19.4s, v31.8h, v4.8h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "smlal v17.4s, v31.4h, v3.4h\n"
-    "ldr x26, [x14, #0x38]\n"
-    "smlal2 v25.4s, v31.8h, v3.8h\n"
-    "ldr x25, [x14, #0x40]\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "ldr x19, [x14, #0x48]\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "ldr x24, [x14, #0x50]\n"
-    "smlal v23.4s, v31.4h, v0.4h\n"
-    "ldr x23, [x14, #0x58]\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x21, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x12, #0x28]\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x12, #0x30]\n"
+    "ldr x25, [x12, #0x40]\n"
     "smlal v13.4s, v30.4h, v0.4h\n"
-    "ldr x22, [x14, #0x60]\n"
-    "smlal2 v19.4s, v30.8h, v0.8h\n"
-    "ldr d30, [x19, x17]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "ldr x21, [x14, #0x68]\n"
-    "smlal2 v25.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x12, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x12, #0x50]\n"
+    "ldr x23, [x12, #0x58]\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
     "smlal v13.4s, v28.4h, v5.4h\n"
-    "ldr x20, [x14, #0x70]\n"
-    "smlal2 v19.4s, v28.8h, v5.8h\n"
-    "ldr x19, [x14, #0x78]\n"
-    "smlal v17.4s, v28.4h, v4.4h\n"
-    "ldr q26, [x13, #0x0]\n"
-    "smlal2 v25.4s, v28.8h, v4.8h\n"
-    "ldr q10, [x11, #0x0]\n"
-    "smlal v16.4s, v28.4h, v2.4h\n"
-    "ldr q11, [x13, #0x10]\n"
-    "add x13, x13, #0x20\n"
-    "smlal2 v21.4s, v28.8h, v2.8h\n"
-    "ldr q18, [x11, #0x10]\n"
-    "add x11, x11, #0x20\n"
-    "smlal v23.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x26, x17]\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v16.4s, v31.4h, v6.4h\n"
-    "smlal2 v21.4s, v31.8h, v6.8h\n"
-    "ldr d31, [x25, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v5.8h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x12, #0x68]\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x12, #0x78]\n"
+    "ldr q21, [x13, #0x0]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
     "smlal v13.4s, v27.4h, v7.4h\n"
-    "smlal2 v19.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v27.4h, v6.4h\n"
-    "smlal2 v25.4s, v27.8h, v6.8h\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "smlal v23.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
+    "smlal2 v26.4s, v27.8h, v7.8h\n"
+    "ldr q25, [x11, #0x0]\n"
+    "ldr q10, [x13, #0x10]\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "ldr q16, [x11, #0x10]\n"
+    "tst x8, #0x7\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
+    "add x13, x13, #0x20\n"
+    "add x11, x11, #0x20\n"
     "smlal v13.4s, v28.4h, v1.4h\n"
-    "smlal2 v19.4s, v28.8h, v1.8h\n"
-    "smlal v23.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "ldr d29, [x24, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v0.4h\n"
-    "smlal2 v25.4s, v28.8h, v0.8h\n"
-    "ldr d28, [x23, x17]\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
     "smlal v13.4s, v31.4h, v2.4h\n"
-    "smlal2 v19.4s, v31.8h, v2.8h\n"
-    "smlal v17.4s, v31.4h, v1.4h\n"
-    "smlal2 v25.4s, v31.8h, v1.8h\n"
-    "ldr d31, [x22, x17]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "smlal2 v26.4s, v31.8h, v2.8h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
     "smlal v13.4s, v30.4h, v8.4h\n"
-    "smlal2 v19.4s, v30.8h, v8.8h\n"
-    "smlal v17.4s, v30.4h, v7.4h\n"
-    "smlal2 v25.4s, v30.8h, v7.8h\n"
-    "smlal v16.4s, v30.4h, v5.4h\n"
-    "smlal2 v21.4s, v30.8h, v5.8h\n"
-    "smlal v23.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x17]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
+    "smlal2 v26.4s, v30.8h, v8.8h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
     "smlal v13.4s, v29.4h, v3.4h\n"
-    "smlal2 v19.4s, v29.8h, v3.8h\n"
-    "smlal v16.4s, v29.4h, v0.4h\n"
-    "smlal2 v21.4s, v29.8h, v0.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v5.4h\n"
-    "smlal2 v25.4s, v28.8h, v5.8h\n"
-    "smlal v23.4s, v28.4h, v2.4h\n"
-    "smlal2 v20.4s, v28.8h, v2.8h\n"
-    "ldr d28, [x19, x17]\n"
-    "add x17, x17, #0x8\n"
-    "smlal v13.4s, v31.4h, v6.4h\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "smlal2 v19.4s, v31.8h, v6.8h\n"
-    "smlal v16.4s, v31.4h, v3.4h\n"
-    "smlal2 v21.4s, v31.8h, v3.8h\n"
-    "smlal v17.4s, v30.4h, v8.4h\n"
-    "smlal2 v25.4s, v30.8h, v8.8h\n"
-    "smlal v23.4s, v30.4h, v5.4h\n"
-    "smlal2 v20.4s, v30.8h, v5.8h\n"
-    "smlal v16.4s, v29.4h, v7.4h\n"
-    "smlal2 v21.4s, v29.8h, v7.8h\n"
-    "smlal v23.4s, v29.4h, v6.4h\n"
-    "smlal2 v20.4s, v29.8h, v6.8h\n"
-    "smlal v16.4s, v28.4h, v8.4h\n"
-    "smlal2 v21.4s, v28.8h, v8.8h\n"
-    "smlal v23.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
-    "sqrdmulh v13.4s, v13.4s, v26.4s\n"
-    "sqrdmulh v19.4s, v19.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v26.4s\n"
-    "sqrdmulh v25.4s, v25.4s, v11.4s\n"
-    "and v22.16b, v13.16b, v10.16b\n"
-    "sshr v22.4s, v22.4s, #0x1f\n"
-    "and v28.16b, v19.16b, v18.16b\n"
-    "and v3.16b, v17.16b, v10.16b\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "and v6.16b, v25.16b, v18.16b\n"
-    "sqrdmulh v16.4s, v16.4s, v26.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v11.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sqadd v13.4s, v13.4s, v22.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v26.4s\n"
-    "and v0.16b, v16.16b, v10.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v13.4s, v13.4s, v10.4s\n"
-    "sqadd v19.4s, v19.4s, v28.4s\n"
-    "sqadd v17.4s, v17.4s, v3.4s\n"
-    "sqadd v25.4s, v25.4s, v6.4s\n"
-    "and v29.16b, v21.16b, v18.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "add v13.4s, v13.4s, v15.4s\n"
-    "srshl v19.4s, v19.4s, v18.4s\n"
-    "srshl v17.4s, v17.4s, v10.4s\n"
-    "srshl v25.4s, v25.4s, v18.4s\n"
-    "smin v13.4s, v13.4s, v12.4s\n"
-    "add v19.4s, v19.4s, v15.4s\n"
-    "add v17.4s, v17.4s, v15.4s\n"
-    "smax v13.4s, v13.4s, v24.4s\n"
-    "smin v19.4s, v19.4s, v12.4s\n"
-    "smin v17.4s, v17.4s, v12.4s\n"
-    "add v25.4s, v25.4s, v15.4s\n"
-    "smax v19.4s, v19.4s, v24.4s\n"
-    "smax v17.4s, v17.4s, v24.4s\n"
-    "smin v25.4s, v25.4s, v12.4s\n"
-    "uzp1 v13.16b, v13.16b, v19.16b\n"
-    "sqadd v16.4s, v16.4s, v0.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "str d13, [x10, x15]\n"
-    "smax v25.4s, v25.4s, v24.4s\n"
-    "sqadd v21.4s, v21.4s, v29.4s\n"
-    "srshl v16.4s, v16.4s, v10.4s\n"
-    "and v3.16b, v23.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v25.16b\n"
-    "add v16.4s, v16.4s, v15.4s\n"
-    "srshl v21.4s, v21.4s, v18.4s\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x9, x15]\n"
-    "smin v16.4s, v16.4s, v12.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v11.4s\n"
-    "add v21.4s, v21.4s, v15.4s\n"
-    "sqadd v23.4s, v23.4s, v3.4s\n"
-    "smax v16.4s, v16.4s, v24.4s\n"
-    "smin v21.4s, v21.4s, v12.4s\n"
-    "and v25.16b, v20.16b, v18.16b\n"
-    "sshr v25.4s, v25.4s, #0x1f\n"
-    "smax v21.4s, v21.4s, v24.4s\n"
-    "srshl v23.4s, v23.4s, v10.4s\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "add v23.4s, v23.4s, v15.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x28, x15]\n"
-    "smin v23.4s, v23.4s, v12.4s\n"
-    "sqadd v20.4s, v20.4s, v25.4s\n"
-    "smax v23.4s, v23.4s, v24.4s\n"
-    "srshl v20.4s, v20.4s, v18.4s\n"
-    "add v20.4s, v20.4s, v15.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v24.4s\n"
-    "uzp1 v23.16b, v23.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
-    "str d23, [x27, x15]\n"
+    "smlal2 v26.4s, v29.8h, v3.8h\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
     "add x15, x15, #0x8\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "str d13, [x10, x14]\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "str d19, [x9, x14]\n"
+    "str d18, [x28, x14]\n"
+    "str d9, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
     "beq 64f\n"
-    "add x16, x16, #0x48\n"
+    "add x17, x17, #0x48\n"
     "3:"  // Oddments
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "tbz x8, #2, 5f\n"
     "ld1 { v13.4s }, [x19], #0x10\n"
     "tbz x8, #1, 4f\n"
-    "ld1 { v19.d }[0], [x19], #0x8\n"
+    "ld1 { v26.d }[0], [x19], #0x8\n"
     "tbz x8, #0, 7f\n"
-    "ld1 { v19.s }[2], [x19]\n"
+    "ld1 { v26.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
     "tbz x8, #0, 7f\n"
-    "ld1 { v19.s }[0], [x19]\n"
+    "ld1 { v26.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
     "tbz x8, #1, 6f\n"
@@ -609,38 +593,38 @@
     "tbz x8, #0, 7f\n"
     "ld1 { v13.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v17.16b, v13.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "mov v25.16b, v19.16b\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v16.16b, v13.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "mov v21.16b, v19.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "mov v23.16b, v13.16b\n"
-    "ldr d4, [x16, #0x20]\n"
-    "usubl v0.8h, v0.8b, v9.8b\n"
-    "mov v20.16b, v19.16b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "usubl v1.8h, v1.8b, v9.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "usubl v2.8h, v2.8b, v9.8b\n"
-    "ldr d7, [x16, #0x38]\n"
-    "usubl v3.8h, v3.8b, v9.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "usubl v4.8h, v4.8b, v9.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "usubl v5.8h, v5.8b, v9.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "usubl v6.8h, v6.8b, v9.8b\n"
-    "usubl v7.8h, v7.8b, v9.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "usubl v8.8h, v8.8b, v9.8b\n"
-    "add x23, x23, x17\n"
-    "add x22, x22, x17\n"
-    "add x21, x21, x17\n"
-    "add x20, x20, x17\n"
-    "add x19, x19, x17\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v19.16b, v13.16b\n"
+    "mov v11.16b, v26.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v18.16b, v13.16b\n"
+    "mov v24.16b, v26.16b\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v9.16b, v13.16b\n"
+    "mov v23.16b, v26.16b\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "usubl v0.8h, v0.8b, v12.8b\n"
+    "usubl v1.8h, v1.8b, v12.8b\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "usubl v2.8h, v2.8b, v12.8b\n"
+    "usubl v3.8h, v3.8b, v12.8b\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "ldr x19, [x12, #0x20]\n"
+    "usubl v4.8h, v4.8b, v12.8b\n"
+    "usubl v5.8h, v5.8b, v12.8b\n"
+    "usubl v6.8h, v6.8b, v12.8b\n"
+    "usubl v7.8h, v7.8b, v12.8b\n"
+    "usubl v8.8h, v8.8b, v12.8b\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 9f\n"
     "ld1 { v31.s }[0], [x23], #0x4\n"
     "ld1 { v30.s }[0], [x22], #0x4\n"
@@ -690,33 +674,33 @@
     "ld1 { v28.b }[0], [x20]\n"
     "ld1 { v27.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "ldr x21, [x14, #0x28]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
     "smlal v13.4s, v31.4h, v4.4h\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
-    "smlal2 v19.4s, v31.8h, v4.8h\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v17.4s, v31.4h, v3.4h\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "smlal2 v25.4s, v31.8h, v3.8h\n"
-    "usubl v27.8h, v27.8b, v14.8b\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "add x21, x21, x17\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "smlal v23.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
+    "smlal2 v26.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x12, #0x28]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "add x21, x21, x15\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
     "smlal v13.4s, v30.4h, v0.4h\n"
-    "smlal2 v19.4s, v30.8h, v0.8h\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "smlal2 v25.4s, v29.8h, v2.8h\n"
+    "smlal2 v26.4s, v30.8h, v0.8h\n"
+    "usubl v27.8h, v27.8b, v22.8b\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
     "smlal v13.4s, v28.4h, v5.4h\n"
-    "smlal2 v19.4s, v28.8h, v5.8h\n"
-    "smlal v17.4s, v28.4h, v4.4h\n"
-    "smlal2 v25.4s, v28.8h, v4.8h\n"
-    "smlal v16.4s, v28.4h, v2.4h\n"
-    "smlal2 v21.4s, v28.8h, v2.8h\n"
-    "smlal v23.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
+    "smlal2 v26.4s, v28.8h, v5.8h\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
     "tbz x8, #2, 13f\n"
     "ld1 { v31.s }[0], [x21], #0x4\n"
     "tbz x8, #1, 12f\n"
@@ -738,19 +722,19 @@
     "tbz x8, #0, 15f\n"
     "ld1 { v31.b }[0], [x21]\n"
     "15:"  // Oddments: Load (3, 0): Bit 2: End
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr x20, [x12, #0x30]\n"
     "smlal v13.4s, v27.4h, v7.4h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
-    "smlal2 v19.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v27.4h, v6.4h\n"
-    "add x20, x20, x17\n"
-    "smlal2 v25.4s, v27.8h, v6.8h\n"
-    "smlal v23.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "smlal v16.4s, v31.4h, v6.4h\n"
-    "smlal2 v21.4s, v31.8h, v6.8h\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
+    "smlal2 v26.4s, v27.8h, v7.8h\n"
+    "add x20, x20, x15\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
     "tbz x8, #2, 17f\n"
     "ld1 { v29.s }[0], [x20], #0x4\n"
     "tbz x8, #1, 16f\n"
@@ -772,11 +756,11 @@
     "tbz x8, #0, 19f\n"
     "ld1 { v29.b }[0], [x20]\n"
     "19:"  // Oddments: Load (3, 3): Bit 2: End
-    "ldr x26, [x14, #0x38]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v23.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "add x26, x26, x17\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "add x26, x26, x15\n"
     "tbz x8, #2, 21f\n"
     "ld1 { v28.s }[0], [x26], #0x4\n"
     "tbz x8, #1, 20f\n"
@@ -798,13 +782,13 @@
     "tbz x8, #0, 23f\n"
     "ld1 { v28.b }[0], [x26]\n"
     "23:"  // Oddments: Load (0, 1): Bit 2: End
-    "ldr x25, [x14, #0x40]\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "ldr x25, [x12, #0x40]\n"
     "smlal v13.4s, v28.4h, v1.4h\n"
-    "smlal2 v19.4s, v28.8h, v1.8h\n"
-    "add x25, x25, x17\n"
-    "smlal v17.4s, v28.4h, v0.4h\n"
-    "smlal2 v25.4s, v28.8h, v0.8h\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "add x25, x25, x15\n"
     "tbz x8, #2, 25f\n"
     "ld1 { v31.s }[0], [x25], #0x4\n"
     "tbz x8, #1, 24f\n"
@@ -826,13 +810,13 @@
     "tbz x8, #0, 27f\n"
     "ld1 { v31.b }[0], [x25]\n"
     "27:"  // Oddments: Load (0, 2): Bit 2: End
-    "ldr x19, [x14, #0x48]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "ldr x19, [x12, #0x48]\n"
     "smlal v13.4s, v31.4h, v2.4h\n"
-    "smlal2 v19.4s, v31.8h, v2.8h\n"
-    "add x19, x19, x17\n"
-    "smlal v17.4s, v31.4h, v1.4h\n"
-    "smlal2 v25.4s, v31.8h, v1.8h\n"
+    "smlal2 v26.4s, v31.8h, v2.8h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 29f\n"
     "ld1 { v30.s }[0], [x19], #0x4\n"
     "tbz x8, #1, 28f\n"
@@ -854,17 +838,17 @@
     "tbz x8, #0, 31f\n"
     "ld1 { v30.b }[0], [x19]\n"
     "31:"  // Oddments: Load (2, 2): Bit 2: End
-    "ldr x24, [x14, #0x50]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x24, [x12, #0x50]\n"
     "smlal v13.4s, v30.4h, v8.4h\n"
-    "smlal2 v19.4s, v30.8h, v8.8h\n"
-    "add x24, x24, x17\n"
-    "smlal v17.4s, v30.4h, v7.4h\n"
-    "smlal2 v25.4s, v30.8h, v7.8h\n"
-    "smlal v16.4s, v30.4h, v5.4h\n"
-    "smlal2 v21.4s, v30.8h, v5.8h\n"
-    "smlal v23.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
+    "smlal2 v26.4s, v30.8h, v8.8h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "add x24, x24, x15\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
     "tbz x8, #2, 33f\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
     "tbz x8, #1, 32f\n"
@@ -886,13 +870,13 @@
     "tbz x8, #0, 35f\n"
     "ld1 { v29.b }[0], [x24]\n"
     "35:"  // Oddments: Load (1, 0): Bit 2: End
-    "ldr x23, [x14, #0x58]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x23, [x12, #0x58]\n"
     "smlal v13.4s, v29.4h, v3.4h\n"
-    "smlal2 v19.4s, v29.8h, v3.8h\n"
-    "add x23, x23, x17\n"
-    "smlal v16.4s, v29.4h, v0.4h\n"
-    "smlal2 v21.4s, v29.8h, v0.8h\n"
+    "smlal2 v26.4s, v29.8h, v3.8h\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "add x23, x23, x15\n"
     "tbz x8, #2, 37f\n"
     "ld1 { v28.s }[0], [x23], #0x4\n"
     "tbz x8, #1, 36f\n"
@@ -914,13 +898,13 @@
     "tbz x8, #0, 39f\n"
     "ld1 { v28.b }[0], [x23]\n"
     "39:"  // Oddments: Load (1, 3): Bit 2: End
-    "ldr x22, [x14, #0x60]\n"
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v17.4s, v28.4h, v5.4h\n"
-    "smlal2 v25.4s, v28.8h, v5.8h\n"
-    "add x22, x22, x17\n"
-    "smlal v23.4s, v28.4h, v2.4h\n"
-    "smlal2 v20.4s, v28.8h, v2.8h\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "add x22, x22, x15\n"
     "tbz x8, #2, 41f\n"
     "ld1 { v31.s }[0], [x22], #0x4\n"
     "tbz x8, #1, 40f\n"
@@ -942,13 +926,13 @@
     "tbz x8, #0, 43f\n"
     "ld1 { v31.b }[0], [x22]\n"
     "43:"  // Oddments: Load (2, 0): Bit 2: End
-    "ldr x21, [x14, #0x68]\n"
-    "usubl v31.8h, v31.8b, v14.8b\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "ldr x21, [x12, #0x68]\n"
     "smlal v13.4s, v31.4h, v6.4h\n"
-    "smlal2 v19.4s, v31.8h, v6.8h\n"
-    "add x21, x21, x17\n"
-    "smlal v16.4s, v31.4h, v3.4h\n"
-    "smlal2 v21.4s, v31.8h, v3.8h\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "add x21, x21, x15\n"
     "tbz x8, #2, 45f\n"
     "ld1 { v30.s }[0], [x21], #0x4\n"
     "tbz x8, #1, 44f\n"
@@ -970,13 +954,13 @@
     "tbz x8, #0, 47f\n"
     "ld1 { v30.b }[0], [x21]\n"
     "47:"  // Oddments: Load (2, 3): Bit 2: End
-    "ldr x20, [x14, #0x70]\n"
-    "usubl v30.8h, v30.8b, v14.8b\n"
-    "smlal v17.4s, v30.4h, v8.4h\n"
-    "smlal2 v25.4s, v30.8h, v8.8h\n"
-    "add x20, x20, x17\n"
-    "smlal v23.4s, v30.4h, v5.4h\n"
-    "smlal2 v20.4s, v30.8h, v5.8h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
+    "add x20, x20, x15\n"
     "tbz x8, #2, 49f\n"
     "ld1 { v29.s }[0], [x20], #0x4\n"
     "tbz x8, #1, 48f\n"
@@ -998,13 +982,13 @@
     "tbz x8, #0, 51f\n"
     "ld1 { v29.b }[0], [x20]\n"
     "51:"  // Oddments: Load (3, 1): Bit 2: End
-    "ldr x19, [x14, #0x78]\n"
-    "usubl v29.8h, v29.8b, v14.8b\n"
-    "smlal v16.4s, v29.4h, v7.4h\n"
-    "smlal2 v21.4s, v29.8h, v7.8h\n"
-    "add x19, x19, x17\n"
-    "smlal v23.4s, v29.4h, v6.4h\n"
-    "smlal2 v20.4s, v29.8h, v6.8h\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x19, [x12, #0x78]\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 53f\n"
     "ld1 { v28.s }[0], [x19], #0x4\n"
     "tbz x8, #1, 52f\n"
@@ -1026,160 +1010,150 @@
     "tbz x8, #0, 55f\n"
     "ld1 { v28.b }[0], [x19]\n"
     "55:"  // Oddments: Load (3, 2): Bit 2: End
-    "usubl v28.8h, v28.8b, v14.8b\n"
-    "smlal v16.4s, v28.4h, v8.4h\n"
-    "smlal2 v21.4s, v28.8h, v8.8h\n"
-    "smlal v23.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
     "tbz x8, #2, 57f\n"
-    "ld1 { v26.4s }, [x13], #0x10\n"
-    "ld1 { v10.4s }, [x11], #0x10\n"
+    "ld1 { v21.4s }, [x13], #0x10\n"
+    "ld1 { v25.4s }, [x11], #0x10\n"
     "tbz x8, #1, 56f\n"
-    "ld1 { v11.d }[0], [x13], #0x8\n"
-    "ld1 { v18.d }[0], [x11], #0x8\n"
+    "ld1 { v10.d }[0], [x13], #0x8\n"
+    "ld1 { v16.d }[0], [x11], #0x8\n"
     "tbz x8, #0, 59f\n"
-    "ld1 { v11.s }[2], [x13]\n"
-    "ld1 { v18.s }[2], [x11]\n"
+    "ld1 { v10.s }[2], [x13]\n"
+    "ld1 { v16.s }[2], [x11]\n"
     "b 59f\n"
     "56:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
     "tbz x8, #0, 59f\n"
-    "ld1 { v11.s }[0], [x13]\n"
-    "ld1 { v18.s }[0], [x11]\n"
+    "ld1 { v10.s }[0], [x13]\n"
+    "ld1 { v16.s }[0], [x11]\n"
     "b 59f\n"
     "57:"  // Oddments: Load requant params: Bit 2: Unset
     "tbz x8, #1, 58f\n"
-    "ld1 { v26.d }[0], [x13], #0x8\n"
-    "ld1 { v10.d }[0], [x11], #0x8\n"
+    "ld1 { v21.d }[0], [x13], #0x8\n"
+    "ld1 { v25.d }[0], [x11], #0x8\n"
     "tbz x8, #0, 59f\n"
-    "ld1 { v26.s }[2], [x13]\n"
-    "ld1 { v10.s }[2], [x11]\n"
+    "ld1 { v21.s }[2], [x13]\n"
+    "ld1 { v25.s }[2], [x11]\n"
     "b 59f\n"
     "58:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
     "tbz x8, #0, 59f\n"
-    "ld1 { v26.s }[0], [x13]\n"
-    "ld1 { v10.s }[0], [x11]\n"
+    "ld1 { v21.s }[0], [x13]\n"
+    "ld1 { v25.s }[0], [x11]\n"
     "59:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v13.4s, v13.4s, v26.4s\n"
-    "add x10, x10, x15\n"
-    "sqrdmulh v19.4s, v19.4s, v11.4s\n"
-    "add x9, x9, x15\n"
-    "sqrdmulh v17.4s, v17.4s, v26.4s\n"
-    "add x28, x28, x15\n"
-    "sqrdmulh v25.4s, v25.4s, v11.4s\n"
-    "add x27, x27, x15\n"
-    "sqrdmulh v16.4s, v16.4s, v26.4s\n"
-    "and v22.16b, v13.16b, v10.16b\n"
-    "sshr v22.4s, v22.4s, #0x1f\n"
-    "and v28.16b, v19.16b, v18.16b\n"
-    "and v3.16b, v17.16b, v10.16b\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "and v6.16b, v25.16b, v18.16b\n"
-    "and v0.16b, v16.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v11.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sqadd v13.4s, v13.4s, v22.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v26.4s\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "sqrdmulh v20.4s, v20.4s, v11.4s\n"
-    "sqadd v19.4s, v19.4s, v28.4s\n"
-    "sqadd v17.4s, v17.4s, v3.4s\n"
-    "srshl v13.4s, v13.4s, v10.4s\n"
-    "sqadd v25.4s, v25.4s, v6.4s\n"
-    "srshl v19.4s, v19.4s, v18.4s\n"
-    "srshl v17.4s, v17.4s, v10.4s\n"
-    "add v13.4s, v13.4s, v15.4s\n"
-    "srshl v25.4s, v25.4s, v18.4s\n"
-    "add v19.4s, v19.4s, v15.4s\n"
-    "smin v13.4s, v13.4s, v12.4s\n"
-    "add v17.4s, v17.4s, v15.4s\n"
-    "smin v19.4s, v19.4s, v12.4s\n"
-    "smax v13.4s, v13.4s, v24.4s\n"
-    "smin v17.4s, v17.4s, v12.4s\n"
-    "smax v19.4s, v19.4s, v24.4s\n"
-    "add v25.4s, v25.4s, v15.4s\n"
-    "smax v17.4s, v17.4s, v24.4s\n"
-    "uzp1 v13.16b, v13.16b, v19.16b\n"
-    "smin v25.4s, v25.4s, v12.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "sqadd v16.4s, v16.4s, v0.4s\n"
-    "smax v25.4s, v25.4s, v24.4s\n"
-    "and v29.16b, v21.16b, v18.16b\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
     "sshr v29.4s, v29.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v25.16b\n"
-    "srshl v16.4s, v16.4s, v10.4s\n"
-    "and v3.16b, v23.16b, v10.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "add v16.4s, v16.4s, v15.4s\n"
-    "sqadd v21.4s, v21.4s, v29.4s\n"
-    "and v25.16b, v20.16b, v18.16b\n"
-    "sshr v25.4s, v25.4s, #0x1f\n"
-    "smin v16.4s, v16.4s, v12.4s\n"
-    "srshl v21.4s, v21.4s, v18.4s\n"
-    "sqadd v23.4s, v23.4s, v3.4s\n"
-    "smax v16.4s, v16.4s, v24.4s\n"
-    "add v21.4s, v21.4s, v15.4s\n"
-    "srshl v23.4s, v23.4s, v10.4s\n"
-    "sqadd v20.4s, v20.4s, v25.4s\n"
-    "smin v21.4s, v21.4s, v12.4s\n"
-    "add v23.4s, v23.4s, v15.4s\n"
-    "srshl v20.4s, v20.4s, v18.4s\n"
-    "smax v21.4s, v21.4s, v24.4s\n"
-    "smin v23.4s, v23.4s, v12.4s\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "add v20.4s, v20.4s, v15.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "smax v23.4s, v23.4s, v24.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v24.4s\n"
-    "uzp1 v23.16b, v23.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
     "tbz x8, #2, 61f\n"
     "st1 { v13.s }[0], [x10], #0x4\n"
-    "st1 { v17.s }[0], [x9], #0x4\n"
-    "st1 { v16.s }[0], [x28], #0x4\n"
-    "st1 { v23.s }[0], [x27], #0x4\n"
+    "st1 { v19.s }[0], [x9], #0x4\n"
+    "st1 { v18.s }[0], [x28], #0x4\n"
+    "st1 { v9.s }[0], [x27], #0x4\n"
     "tbz x8, #1, 60f\n"
     "st1 { v13.h }[2], [x10], #0x2\n"
-    "st1 { v17.h }[2], [x9], #0x2\n"
-    "st1 { v16.h }[2], [x28], #0x2\n"
-    "st1 { v23.h }[2], [x27], #0x2\n"
+    "st1 { v19.h }[2], [x9], #0x2\n"
+    "st1 { v18.h }[2], [x28], #0x2\n"
+    "st1 { v9.h }[2], [x27], #0x2\n"
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[6], [x10], #0x1\n"
-    "st1 { v17.b }[6], [x9], #0x1\n"
-    "st1 { v16.b }[6], [x28], #0x1\n"
-    "st1 { v23.b }[6], [x27], #0x1\n"
+    "st1 { v19.b }[6], [x9], #0x1\n"
+    "st1 { v18.b }[6], [x28], #0x1\n"
+    "st1 { v9.b }[6], [x27], #0x1\n"
     "b 63f\n"
     "60:"  // Oddments: Bit 2: Bit 1: Unset
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[4], [x10], #0x1\n"
-    "st1 { v17.b }[4], [x9], #0x1\n"
-    "st1 { v16.b }[4], [x28], #0x1\n"
-    "st1 { v23.b }[4], [x27], #0x1\n"
+    "st1 { v19.b }[4], [x9], #0x1\n"
+    "st1 { v18.b }[4], [x28], #0x1\n"
+    "st1 { v9.b }[4], [x27], #0x1\n"
     "b 63f\n"
     "61:"  // Oddments: Bit 2: Unset
     "tbz x8, #1, 62f\n"
     "st1 { v13.h }[0], [x10], #0x2\n"
-    "st1 { v17.h }[0], [x9], #0x2\n"
-    "st1 { v16.h }[0], [x28], #0x2\n"
-    "st1 { v23.h }[0], [x27], #0x2\n"
+    "st1 { v19.h }[0], [x9], #0x2\n"
+    "st1 { v18.h }[0], [x28], #0x2\n"
+    "st1 { v9.h }[0], [x27], #0x2\n"
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[2], [x10], #0x1\n"
-    "st1 { v17.b }[2], [x9], #0x1\n"
-    "st1 { v16.b }[2], [x28], #0x1\n"
-    "st1 { v23.b }[2], [x27], #0x1\n"
+    "st1 { v19.b }[2], [x9], #0x1\n"
+    "st1 { v18.b }[2], [x28], #0x1\n"
+    "st1 { v9.b }[2], [x27], #0x1\n"
     "b 63f\n"
     "62:"  // Oddments: Bit 2: Unset: Bit 1: Unset
     "tbz x8, #0, 63f\n"
     "st1 { v13.b }[0], [x10], #0x1\n"
-    "st1 { v17.b }[0], [x9], #0x1\n"
-    "st1 { v16.b }[0], [x28], #0x1\n"
-    "st1 { v23.b }[0], [x27], #0x1\n"
+    "st1 { v19.b }[0], [x9], #0x1\n"
+    "st1 { v18.b }[0], [x28], #0x1\n"
+    "st1 { v9.b }[0], [x27], #0x1\n"
     "63:"  // Oddments: Bit 2: End
-
     "64:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index 44817db..00d1c5e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst
+class a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const uint8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 2, 2) {}
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_u8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_u8q_3x3_mla::get_packed_size;
-
-  kern_type kernel = a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
-
-  a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
index 8d22836..872f665 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const uint8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const uint8_t *const *inptrs_raw,
-      const uint8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -100,75 +100,75 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
-    "ldr x4, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x5, #0x0\n"
-    "ldr x6, [%x[params], %[offsetof_Params_weights]]\n"
-    "mov x7, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x8, %x[params], %[offsetof_Params_inptrs]\n"
-    "ldr x17, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x16, x4, #0x3\n"
-    "ldr x15, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x19, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v12.16b }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v13.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v11.4s }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v19.4s }, [x20]\n"
-    "ld1r { v14.4s }, [x19]\n"
-    "ldp x14, x13, [x21, #0x0]\n"
-    "ldp x12, x11, [x21, #0x10]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x24, x19, %[offsetof_Requantize32_a_offset]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x22, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x21, x19, %[offsetof_Requantize32_c_offset]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v12.16b }, [x24]\n"
+    "ld1r { v13.16b }, [x23]\n"
+    "lsr x16, x8, #0x3\n"
+    "ld1r { v11.8h }, [x21]\n"
+    "ld1r { v17.8h }, [x20]\n"
+    "mov x15, #0x0\n"
+    "mov x14, #0x0\n"
+    "ld1r { v14.8h }, [x19]\n"
+    "ldr x13, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "add x12, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x10, x9, [x22, #0x0]\n"
+    "ldp x28, x27, [x22, #0x10]\n"
     "cbz x16, 3f\n"
-    "subs x16, x16, #0x1\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q15, [x19, #0x0]\n"
-    "mov v20.16b, v15.16b\n"
+    "subs x16, x16, #0x1\n"
+    "mov v9.16b, v15.16b\n"
     "ldr q10, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v15.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v17.16b, v15.16b\n"
-    "ldr d0, [x6, #0x0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "mov v23.16b, v10.16b\n"
-    "ldr d1, [x6, #0x8]\n"
-    "mov v22.16b, v10.16b\n"
-    "ldr d2, [x6, #0x10]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v16.16b, v10.16b\n"
+    "mov v22.16b, v15.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v21.16b, v10.16b\n"
+    "mov v23.16b, v15.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
     "mov v18.16b, v10.16b\n"
-    "ldr d3, [x6, #0x18]\n"
-    "ldr d4, [x6, #0x20]\n"
+    "usubl v0.8h, v0.8b, v13.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v1.8h, v1.8b, v13.8b\n"
     "usubl v2.8h, v2.8b, v13.8b\n"
-    "ldr d5, [x6, #0x28]\n"
+    "ldp x26, x25, [x12, #0x0]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
     "usubl v3.8h, v3.8b, v13.8b\n"
-    "ldr d6, [x6, #0x30]\n"
-    "ldr d7, [x6, #0x38]\n"
     "usubl v4.8h, v4.8b, v13.8b\n"
-    "ldr d8, [x6, #0x40]\n"
+    "ldp x22, x21, [x12, #0x20]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
     "usubl v5.8h, v5.8b, v13.8b\n"
-    "ldp x26, x25, [x8, #0x0]\n"
     "usubl v6.8h, v6.8b, v13.8b\n"
-    "ldp x24, x23, [x8, #0x10]\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
     "usubl v7.8h, v7.8b, v13.8b\n"
     "usubl v8.8h, v8.8b, v13.8b\n"
-    "ldp x22, x21, [x8, #0x20]\n"
-    "ldp x20, x19, [x8, #0x30]\n"
-    "ldr d31, [x26, x5]\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
     "usubl v31.8h, v31.8b, v12.8b\n"
-    "ldr d30, [x25, x5]\n"
-    "ldr d29, [x24, x5]\n"
     "usubl v30.8h, v30.8b, v12.8b\n"
-    "ldr d28, [x23, x5]\n"
-    "ldr d27, [x22, x5]\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "ldr d26, [x21, x5]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
-    "ldr d25, [x20, x5]\n"
-    "ldr d24, [x19, x5]\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
     "usubl v25.8h, v25.8b, v12.8b\n"
@@ -176,259 +176,251 @@
     "beq 2f\n"
     "1:"  // Loop
     "smlal v15.4s, v31.4h, v8.4h\n"
-    "ldr x23, [x8, #0x40]\n"
-    "add x6, x6, #0x48\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "ldr x22, [x8, #0x48]\n"
-    "subs x16, x16, #0x1\n"
-    "smlal v20.4s, v31.4h, v6.4h\n"
-    "ldr x21, [x8, #0x50]\n"
-    "smlal2 v23.4s, v31.8h, v6.8h\n"
-    "ldr x20, [x8, #0x58]\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "ldr x19, [x8, #0x60]\n"
-    "smlal2 v22.4s, v31.8h, v2.8h\n"
-    "ldr x10, [x8, #0x68]\n"
-    "smlal v17.4s, v31.4h, v0.4h\n"
-    "ldr x9, [x8, #0x70]\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x8, #0x78]\n"
+    "ldr x24, [x12, #0x40]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x12, #0x50]\n"
+    "ldr x19, [x12, #0x58]\n"
     "smlal v15.4s, v30.4h, v0.4h\n"
-    "ldr x27, [x8, #0x80]\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "ldr x26, [x8, #0x88]\n"
-    "smlal v20.4s, v28.4h, v1.4h\n"
-    "ldr x25, [x8, #0x90]\n"
-    "smlal2 v23.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x5]\n"
+    "ldr x22, [x12, #0x78]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
     "smlal v15.4s, v29.4h, v1.4h\n"
-    "ldr x24, [x8, #0x98]\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "ldr d29, [x23, x5]\n"
+    "ldr d29, [x24, x15]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v2.4h\n"
-    "ldr x23, [x8, #0xa0]\n"
-    "smlal2 v23.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x5]\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
     "smlal v15.4s, v26.4h, v3.4h\n"
-    "ldr x22, [x8, #0xa8]\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x20, x5]\n"
+    "ldr d26, [x19, x15]\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "ldr x19, [x12, #0x68]\n"
     "smlal v15.4s, v25.4h, v4.4h\n"
-    "ldr x21, [x8, #0xb0]\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
-    "ldr d25, [x19, x5]\n"
+    "ldr d25, [x20, x15]\n"
     "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
-    "ldr x20, [x8, #0xb8]\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "ldr x19, [x8, #0xc0]\n"
-    "smlal v20.4s, v24.4h, v0.4h\n"
-    "ldr q21, [x17, #0x0]\n"
-    "smlal2 v23.4s, v24.8h, v0.8h\n"
-    "ldr d24, [x9, x5]\n"
-    "usubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v4.4h\n"
-    "ldr q30, [x15, #0x0]\n"
-    "smlal2 v23.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x10, x5]\n"
+    "ldr x19, [x12, #0x70]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v5.4h\n"
-    "ldr q31, [x17, #0x10]\n"
-    "smlal2 v23.4s, v28.8h, v5.8h\n"
-    "ldr d28, [x27, x5]\n"
-    "add x17, x17, #0x20\n"
-    "smlal v15.4s, v27.4h, v5.4h\n"
-    "ldr q9, [x15, #0x10]\n"
-    "add x15, x15, #0x20\n"
-    "smlal2 v10.4s, v27.8h, v5.8h\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v3.4h\n"
-    "smlal2 v23.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x28, x5]\n"
-    "usubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v26.4h, v3.4h\n"
-    "smlal2 v22.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x26, x5]\n"
-    "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v15.4s, v25.4h, v6.4h\n"
-    "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v22.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x25, x5]\n"
-    "usubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v29.4h, v4.4h\n"
-    "smlal2 v22.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x24, x5]\n"
-    "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v15.4s, v24.4h, v7.4h\n"
-    "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v22.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x22, x5]\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x12, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
+    "smlal2 v10.4s, v27.8h, v5.8h\n"
     "usubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v17.4s, v27.4h, v4.4h\n"
-    "smlal2 v18.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x5]\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v7.4h\n"
-    "smlal2 v23.4s, v28.8h, v7.8h\n"
-    "smlal v17.4s, v28.4h, v1.4h\n"
-    "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "smlal v16.4s, v25.4h, v6.4h\n"
-    "smlal2 v22.4s, v25.8h, v6.8h\n"
-    "ldr d25, [x20, x5]\n"
-    "usubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v5.4h\n"
-    "smlal2 v18.4s, v26.8h, v5.8h\n"
-    "ldr d26, [x21, x5]\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v8.4h\n"
-    "smlal2 v23.4s, v29.8h, v8.8h\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "smlal2 v18.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x19, x5]\n"
-    "add x5, x5, #0x8\n"
-    "smlal v16.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "smlal2 v18.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "ldr q19, [x13, #0x0]\n"
+    "smlal2 v10.4s, v25.8h, v6.8h\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v18.4s, v28.8h, v1.8h\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal2 v22.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v24.4h, v3.4h\n"
-    "smlal v16.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q0, [x11, #0x0]\n"
+    "ldr q4, [x13, #0x10]\n"
+    "smlal2 v10.4s, v24.8h, v7.8h\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "ldr q31, [x11, #0x10]\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v18.4s, v26.8h, v5.8h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v18.4s, v29.8h, v2.8h\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "add x17, x17, #0x48\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "subs x16, x16, #0x1\n"
+    "smlal2 v21.4s, v25.8h, v6.8h\n"
+    "ldr d25, [x20, x15]\n"
     "smlal2 v18.4s, v24.8h, v3.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v21.4s\n"
-    "smlal2 v22.4s, v24.8h, v5.8h\n"
-    "smlal v17.4s, v26.4h, v7.4h\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "add x13, x13, #0x20\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
     "smlal2 v18.4s, v26.8h, v7.8h\n"
-    "smlal v16.4s, v25.4h, v8.4h\n"
-    "smlal2 v22.4s, v25.8h, v8.8h\n"
-    "smlal v17.4s, v25.4h, v6.4h\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "add x11, x11, #0x20\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
     "smlal2 v18.4s, v25.8h, v6.8h\n"
-    "and v26.16b, v15.16b, v30.16b\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "smlal v17.4s, v29.4h, v8.4h\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
     "smlal2 v18.4s, v29.8h, v8.8h\n"
-    "sqrdmulh v10.4s, v10.4s, v31.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v21.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v31.4s\n"
-    "sqrdmulh v16.4s, v16.4s, v21.4s\n"
-    "sqadd v15.4s, v15.4s, v26.4s\n"
-    "and v8.16b, v10.16b, v9.16b\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v30.4s\n"
-    "and v4.16b, v20.16b, v30.16b\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "and v19.16b, v10.16b, v31.16b\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "and v2.16b, v23.16b, v9.16b\n"
-    "and v1.16b, v16.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v11.4s\n"
-    "sqadd v10.4s, v10.4s, v8.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "sqrdmulh v22.4s, v22.4s, v31.4s\n"
-    "sqadd v20.4s, v20.4s, v4.4s\n"
-    "smin v15.4s, v15.4s, v14.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "sqadd v23.4s, v23.4s, v2.4s\n"
-    "smax v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v30.4s\n"
-    "add v10.4s, v10.4s, v11.4s\n"
-    "srshl v23.4s, v23.4s, v9.4s\n"
-    "sqadd v16.4s, v16.4s, v1.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v11.4s\n"
-    "add v23.4s, v23.4s, v11.4s\n"
-    "smax v10.4s, v10.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v14.4s\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
-    "uzp1 v15.16b, v15.16b, v10.16b\n"
-    "smax v20.4s, v20.4s, v19.4s\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x14, x7]\n"
-    "smax v23.4s, v23.4s, v19.4s\n"
-    "srshl v16.4s, v16.4s, v30.4s\n"
-    "and v24.16b, v22.16b, v9.16b\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v23.16b\n"
-    "add v16.4s, v16.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v21.4s\n"
-    "uzp1 v20.16b, v20.16b, v20.16b\n"
-    "str d20, [x13, x7]\n"
-    "smin v16.4s, v16.4s, v14.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "sqadd v22.4s, v22.4s, v24.4s\n"
-    "and v2.16b, v17.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "smax v16.4s, v16.4s, v19.4s\n"
-    "srshl v22.4s, v22.4s, v9.4s\n"
-    "and v31.16b, v18.16b, v9.16b\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "add v22.4s, v22.4s, v11.4s\n"
-    "sqadd v17.4s, v17.4s, v2.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "srshl v17.4s, v17.4s, v30.4s\n"
-    "sqadd v18.4s, v18.4s, v31.4s\n"
-    "smax v22.4s, v22.4s, v19.4s\n"
-    "uzp1 v16.16b, v16.16b, v22.16b\n"
-    "add v17.4s, v17.4s, v11.4s\n"
-    "srshl v18.4s, v18.4s, v9.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x12, x7]\n"
-    "smin v17.4s, v17.4s, v14.4s\n"
-    "add v18.4s, v18.4s, v11.4s\n"
-    "smax v17.4s, v17.4s, v19.4s\n"
-    "smin v18.4s, v18.4s, v14.4s\n"
-    "smax v18.4s, v18.4s, v19.4s\n"
-    "uzp1 v17.16b, v17.16b, v18.16b\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x11, x7]\n"
-    "add x7, x7, #0x8\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "str d9, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d22, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "ldr q15, [x19, #0x0]\n"
-    "mov v20.16b, v15.16b\n"
+    "add x14, x14, #0x8\n"
     "ldr q10, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v16.16b, v15.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v17.16b, v15.16b\n"
-    "ldr d0, [x6, #0x0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "mov v23.16b, v10.16b\n"
-    "ldr d1, [x6, #0x8]\n"
-    "mov v22.16b, v10.16b\n"
-    "ldr d2, [x6, #0x10]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v9.16b, v15.16b\n"
+    "mov v16.16b, v10.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v22.16b, v15.16b\n"
+    "mov v21.16b, v10.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v23.16b, v15.16b\n"
     "mov v18.16b, v10.16b\n"
-    "ldr d3, [x6, #0x18]\n"
-    "ldr d4, [x6, #0x20]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v0.8h, v0.8b, v13.8b\n"
+    "usubl v1.8h, v1.8b, v13.8b\n"
+    "ldp x26, x25, [x12, #0x0]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
     "usubl v2.8h, v2.8b, v13.8b\n"
-    "ldr d5, [x6, #0x28]\n"
     "usubl v3.8h, v3.8b, v13.8b\n"
-    "ldr d6, [x6, #0x30]\n"
-    "ldr d7, [x6, #0x38]\n"
+    "ldp x22, x21, [x12, #0x20]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
     "usubl v4.8h, v4.8b, v13.8b\n"
-    "ldr d8, [x6, #0x40]\n"
     "usubl v5.8h, v5.8b, v13.8b\n"
-    "ldp x26, x25, [x8, #0x0]\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
     "usubl v6.8h, v6.8b, v13.8b\n"
-    "ldp x24, x23, [x8, #0x10]\n"
     "usubl v7.8h, v7.8b, v13.8b\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
     "usubl v8.8h, v8.8b, v13.8b\n"
-    "ldp x22, x21, [x8, #0x20]\n"
-    "ldp x20, x19, [x8, #0x30]\n"
-    "ldr d31, [x26, x5]\n"
     "usubl v31.8h, v31.8b, v12.8b\n"
-    "ldr d30, [x25, x5]\n"
-    "ldr d29, [x24, x5]\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
     "usubl v30.8h, v30.8b, v12.8b\n"
-    "ldr d28, [x23, x5]\n"
-    "ldr d27, [x22, x5]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "ldr d26, [x21, x5]\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
-    "ldr d25, [x20, x5]\n"
-    "ldr d24, [x19, x5]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
     "usubl v25.8h, v25.8b, v12.8b\n"
@@ -436,275 +428,267 @@
     "bgt 1b\n"
     "2:"  // Tail
     "smlal v15.4s, v31.4h, v8.4h\n"
-    "ldr x23, [x8, #0x40]\n"
-    "tst x4, #0x7\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "ldr x22, [x8, #0x48]\n"
-    "smlal v20.4s, v31.4h, v6.4h\n"
-    "ldr x21, [x8, #0x50]\n"
-    "smlal2 v23.4s, v31.8h, v6.8h\n"
-    "ldr x20, [x8, #0x58]\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "ldr x19, [x8, #0x60]\n"
-    "smlal2 v22.4s, v31.8h, v2.8h\n"
-    "ldr x10, [x8, #0x68]\n"
-    "smlal v17.4s, v31.4h, v0.4h\n"
-    "ldr x9, [x8, #0x70]\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x8, #0x78]\n"
+    "ldr x24, [x12, #0x40]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x12, #0x50]\n"
+    "ldr x19, [x12, #0x58]\n"
     "smlal v15.4s, v30.4h, v0.4h\n"
-    "ldr x27, [x8, #0x80]\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "ldr x26, [x8, #0x88]\n"
-    "smlal v20.4s, v28.4h, v1.4h\n"
-    "ldr x25, [x8, #0x90]\n"
-    "smlal2 v23.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x5]\n"
+    "ldr x22, [x12, #0x78]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
     "smlal v15.4s, v29.4h, v1.4h\n"
-    "ldr x24, [x8, #0x98]\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "ldr d29, [x23, x5]\n"
+    "ldr d29, [x24, x15]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v2.4h\n"
-    "ldr x23, [x8, #0xa0]\n"
-    "smlal2 v23.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x5]\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
     "smlal v15.4s, v26.4h, v3.4h\n"
-    "ldr x22, [x8, #0xa8]\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x20, x5]\n"
+    "ldr d26, [x19, x15]\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "ldr x19, [x12, #0x68]\n"
     "smlal v15.4s, v25.4h, v4.4h\n"
-    "ldr x21, [x8, #0xb0]\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
-    "ldr d25, [x19, x5]\n"
+    "ldr d25, [x20, x15]\n"
     "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
-    "ldr x20, [x8, #0xb8]\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "ldr x19, [x8, #0xc0]\n"
-    "smlal v20.4s, v24.4h, v0.4h\n"
-    "ldr q21, [x17, #0x0]\n"
-    "smlal2 v23.4s, v24.8h, v0.8h\n"
-    "ldr d24, [x9, x5]\n"
-    "usubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v4.4h\n"
-    "ldr q30, [x15, #0x0]\n"
-    "smlal2 v23.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x10, x5]\n"
+    "ldr x19, [x12, #0x70]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v5.4h\n"
-    "ldr q31, [x17, #0x10]\n"
-    "smlal2 v23.4s, v28.8h, v5.8h\n"
-    "ldr d28, [x27, x5]\n"
-    "add x17, x17, #0x20\n"
-    "smlal v15.4s, v27.4h, v5.4h\n"
-    "ldr q9, [x15, #0x10]\n"
-    "add x15, x15, #0x20\n"
-    "smlal2 v10.4s, v27.8h, v5.8h\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v27.4h, v3.4h\n"
-    "smlal2 v23.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x28, x5]\n"
-    "usubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v26.4h, v3.4h\n"
-    "smlal2 v22.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x26, x5]\n"
-    "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v15.4s, v25.4h, v6.4h\n"
-    "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v22.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x25, x5]\n"
-    "usubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v29.4h, v4.4h\n"
-    "smlal2 v22.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x24, x5]\n"
-    "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v15.4s, v24.4h, v7.4h\n"
-    "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v22.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x22, x5]\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x12, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
+    "smlal2 v10.4s, v27.8h, v5.8h\n"
     "usubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v17.4s, v27.4h, v4.4h\n"
-    "smlal2 v18.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x5]\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v7.4h\n"
-    "smlal2 v23.4s, v28.8h, v7.8h\n"
-    "smlal v17.4s, v28.4h, v1.4h\n"
-    "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "smlal v16.4s, v25.4h, v6.4h\n"
-    "smlal2 v22.4s, v25.8h, v6.8h\n"
-    "ldr d25, [x20, x5]\n"
-    "usubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v5.4h\n"
-    "smlal2 v18.4s, v26.8h, v5.8h\n"
-    "ldr d26, [x21, x5]\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v8.4h\n"
-    "smlal2 v23.4s, v29.8h, v8.8h\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
-    "smlal2 v18.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x19, x5]\n"
-    "add x5, x5, #0x8\n"
-    "smlal v16.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "smlal2 v18.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "ldr q19, [x13, #0x0]\n"
+    "smlal2 v10.4s, v25.8h, v6.8h\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v18.4s, v28.8h, v1.8h\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal2 v22.4s, v27.8h, v7.8h\n"
-    "smlal v17.4s, v24.4h, v3.4h\n"
-    "smlal v16.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q0, [x11, #0x0]\n"
+    "ldr q4, [x13, #0x10]\n"
+    "smlal2 v10.4s, v24.8h, v7.8h\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "ldr q31, [x11, #0x10]\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v18.4s, v26.8h, v5.8h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v18.4s, v29.8h, v2.8h\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "tst x8, #0x7\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "add x13, x13, #0x20\n"
+    "smlal2 v21.4s, v25.8h, v6.8h\n"
+    "ldr d25, [x20, x15]\n"
     "smlal2 v18.4s, v24.8h, v3.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v21.4s\n"
-    "smlal2 v22.4s, v24.8h, v5.8h\n"
-    "smlal v17.4s, v26.4h, v7.4h\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "add x11, x11, #0x20\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
     "smlal2 v18.4s, v26.8h, v7.8h\n"
-    "smlal v16.4s, v25.4h, v8.4h\n"
-    "smlal2 v22.4s, v25.8h, v8.8h\n"
-    "smlal v17.4s, v25.4h, v6.4h\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
     "smlal2 v18.4s, v25.8h, v6.8h\n"
-    "and v26.16b, v15.16b, v30.16b\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "smlal v17.4s, v29.4h, v8.4h\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
     "smlal2 v18.4s, v29.8h, v8.8h\n"
-    "sqrdmulh v10.4s, v10.4s, v31.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v21.4s\n"
-    "sqrdmulh v23.4s, v23.4s, v31.4s\n"
-    "sqrdmulh v16.4s, v16.4s, v21.4s\n"
-    "sqadd v15.4s, v15.4s, v26.4s\n"
-    "and v8.16b, v10.16b, v9.16b\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v30.4s\n"
-    "and v4.16b, v20.16b, v30.16b\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "and v19.16b, v10.16b, v31.16b\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "and v2.16b, v23.16b, v9.16b\n"
-    "and v1.16b, v16.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v11.4s\n"
-    "sqadd v10.4s, v10.4s, v8.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "sqrdmulh v22.4s, v22.4s, v31.4s\n"
-    "sqadd v20.4s, v20.4s, v4.4s\n"
-    "smin v15.4s, v15.4s, v14.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "sqadd v23.4s, v23.4s, v2.4s\n"
-    "smax v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v30.4s\n"
-    "add v10.4s, v10.4s, v11.4s\n"
-    "srshl v23.4s, v23.4s, v9.4s\n"
-    "sqadd v16.4s, v16.4s, v1.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v11.4s\n"
-    "add v23.4s, v23.4s, v11.4s\n"
-    "smax v10.4s, v10.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v14.4s\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
-    "uzp1 v15.16b, v15.16b, v10.16b\n"
-    "smax v20.4s, v20.4s, v19.4s\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x14, x7]\n"
-    "smax v23.4s, v23.4s, v19.4s\n"
-    "srshl v16.4s, v16.4s, v30.4s\n"
-    "and v24.16b, v22.16b, v9.16b\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v23.16b\n"
-    "add v16.4s, v16.4s, v11.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v21.4s\n"
-    "uzp1 v20.16b, v20.16b, v20.16b\n"
-    "str d20, [x13, x7]\n"
-    "smin v16.4s, v16.4s, v14.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "sqadd v22.4s, v22.4s, v24.4s\n"
-    "and v2.16b, v17.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "smax v16.4s, v16.4s, v19.4s\n"
-    "srshl v22.4s, v22.4s, v9.4s\n"
-    "and v31.16b, v18.16b, v9.16b\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "add v22.4s, v22.4s, v11.4s\n"
-    "sqadd v17.4s, v17.4s, v2.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "srshl v17.4s, v17.4s, v30.4s\n"
-    "sqadd v18.4s, v18.4s, v31.4s\n"
-    "smax v22.4s, v22.4s, v19.4s\n"
-    "uzp1 v16.16b, v16.16b, v22.16b\n"
-    "add v17.4s, v17.4s, v11.4s\n"
-    "srshl v18.4s, v18.4s, v9.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x12, x7]\n"
-    "smin v17.4s, v17.4s, v14.4s\n"
-    "add v18.4s, v18.4s, v11.4s\n"
-    "smax v17.4s, v17.4s, v19.4s\n"
-    "smin v18.4s, v18.4s, v14.4s\n"
-    "smax v18.4s, v18.4s, v19.4s\n"
-    "uzp1 v17.16b, v17.16b, v18.16b\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "str d17, [x11, x7]\n"
-    "add x7, x7, #0x8\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "str d9, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d22, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
     "beq 88f\n"
-    "add x6, x6, #0x48\n"
+    "add x17, x17, #0x48\n"
     "3:"  // Oddments
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "tbz x4, #2, 5f\n"
+    "tbz x8, #2, 5f\n"
     "ld1 { v15.4s }, [x19], #0x10\n"
-    "tbz x4, #1, 4f\n"
+    "tbz x8, #1, 4f\n"
     "ld1 { v10.d }[0], [x19], #0x8\n"
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v10.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v10.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
-    "tbz x4, #1, 6f\n"
+    "tbz x8, #1, 6f\n"
     "ld1 { v15.d }[0], [x19], #0x8\n"
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v15.s }[2], [x19]\n"
     "b 7f\n"
     "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v15.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v20.16b, v15.16b\n"
-    "ldr d0, [x6, #0x0]\n"
-    "mov v23.16b, v10.16b\n"
-    "ldr d1, [x6, #0x8]\n"
-    "mov v16.16b, v15.16b\n"
-    "ldr d2, [x6, #0x10]\n"
-    "mov v22.16b, v10.16b\n"
-    "ldr d3, [x6, #0x18]\n"
-    "mov v17.16b, v15.16b\n"
-    "ldr d4, [x6, #0x20]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v9.16b, v15.16b\n"
+    "mov v16.16b, v10.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v22.16b, v15.16b\n"
+    "mov v21.16b, v10.16b\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v23.16b, v15.16b\n"
     "mov v18.16b, v10.16b\n"
-    "ldr d5, [x6, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "usubl v0.8h, v0.8b, v13.8b\n"
     "usubl v1.8h, v1.8b, v13.8b\n"
-    "ldr d6, [x6, #0x30]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x26, x25, [x12, #0x0]\n"
     "usubl v2.8h, v2.8b, v13.8b\n"
-    "ldr d7, [x6, #0x38]\n"
     "usubl v3.8h, v3.8b, v13.8b\n"
-    "ldr d8, [x6, #0x40]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
+    "ldp x22, x21, [x12, #0x20]\n"
     "usubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x26, x25, [x8, #0x0]\n"
     "usubl v5.8h, v5.8b, v13.8b\n"
-    "ldp x24, x23, [x8, #0x10]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
     "usubl v6.8h, v6.8b, v13.8b\n"
     "usubl v7.8h, v7.8b, v13.8b\n"
-    "ldp x22, x21, [x8, #0x20]\n"
     "usubl v8.8h, v8.8b, v13.8b\n"
-    "ldp x20, x19, [x8, #0x30]\n"
-    "add x26, x26, x5\n"
-    "add x25, x25, x5\n"
-    "add x24, x24, x5\n"
-    "add x23, x23, x5\n"
-    "add x22, x22, x5\n"
-    "add x21, x21, x5\n"
-    "add x20, x20, x5\n"
-    "add x19, x19, x5\n"
-    "tbz x4, #2, 9f\n"
+    "add x26, x26, x15\n"
+    "add x25, x25, x15\n"
+    "add x24, x24, x15\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 9f\n"
     "ld1 { v31.s }[0], [x26], #0x4\n"
     "ld1 { v30.s }[0], [x25], #0x4\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
@@ -713,7 +697,7 @@
     "ld1 { v26.s }[0], [x21], #0x4\n"
     "ld1 { v25.s }[0], [x20], #0x4\n"
     "ld1 { v24.s }[0], [x19], #0x4\n"
-    "tbz x4, #1, 8f\n"
+    "tbz x8, #1, 8f\n"
     "ld1 { v31.h }[2], [x26], #0x2\n"
     "ld1 { v30.h }[2], [x25], #0x2\n"
     "ld1 { v29.h }[2], [x24], #0x2\n"
@@ -722,7 +706,7 @@
     "ld1 { v26.h }[2], [x21], #0x2\n"
     "ld1 { v25.h }[2], [x20], #0x2\n"
     "ld1 { v24.h }[2], [x19], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[6], [x26]\n"
     "ld1 { v30.b }[6], [x25]\n"
     "ld1 { v29.b }[6], [x24]\n"
@@ -733,7 +717,7 @@
     "ld1 { v24.b }[6], [x19]\n"
     "b 11f\n"
     "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[4], [x26]\n"
     "ld1 { v30.b }[4], [x25]\n"
     "ld1 { v29.b }[4], [x24]\n"
@@ -744,7 +728,7 @@
     "ld1 { v24.b }[4], [x19]\n"
     "b 11f\n"
     "9:"  // Oddments: Initial loads: Bit 2: Unset
-    "tbz x4, #1, 10f\n"
+    "tbz x8, #1, 10f\n"
     "ld1 { v31.h }[0], [x26], #0x2\n"
     "ld1 { v30.h }[0], [x25], #0x2\n"
     "ld1 { v29.h }[0], [x24], #0x2\n"
@@ -753,7 +737,7 @@
     "ld1 { v26.h }[0], [x21], #0x2\n"
     "ld1 { v25.h }[0], [x20], #0x2\n"
     "ld1 { v24.h }[0], [x19], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[2], [x26]\n"
     "ld1 { v30.b }[2], [x25]\n"
     "ld1 { v29.b }[2], [x24]\n"
@@ -764,7 +748,7 @@
     "ld1 { v24.b }[2], [x19]\n"
     "b 11f\n"
     "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[0], [x26]\n"
     "ld1 { v30.b }[0], [x25]\n"
     "ld1 { v29.b }[0], [x24]\n"
@@ -774,646 +758,636 @@
     "ld1 { v25.b }[0], [x20]\n"
     "ld1 { v24.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "ldr x23, [x8, #0x40]\n"
     "usubl v31.8h, v31.8b, v12.8b\n"
     "smlal v15.4s, v31.4h, v8.4h\n"
-    "usubl v30.8h, v30.8b, v12.8b\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v31.4h, v6.4h\n"
-    "usubl v28.8h, v28.8b, v12.8b\n"
-    "smlal2 v23.4s, v31.8h, v6.8h\n"
-    "usubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal2 v22.4s, v31.8h, v2.8h\n"
-    "usubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v17.4s, v31.4h, v0.4h\n"
-    "usubl v24.8h, v24.8b, v12.8b\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "add x23, x23, x5\n"
+    "ldr x24, [x12, #0x40]\n"
+    "usubl v30.8h, v30.8b, v12.8b\n"
     "smlal v15.4s, v30.4h, v0.4h\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "smlal v20.4s, v28.4h, v1.4h\n"
-    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "add x24, x24, x15\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
     "smlal v15.4s, v29.4h, v1.4h\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "smlal v20.4s, v27.4h, v2.4h\n"
-    "smlal2 v23.4s, v27.8h, v2.8h\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
     "smlal v15.4s, v26.4h, v3.4h\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "smlal v20.4s, v24.4h, v0.4h\n"
-    "smlal2 v23.4s, v24.8h, v0.8h\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v4.4h\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "tbz x4, #2, 13f\n"
-    "ld1 { v29.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 12f\n"
-    "ld1 { v29.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[6], [x23]\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "tbz x8, #2, 13f\n"
+    "ld1 { v29.s }[0], [x24], #0x4\n"
+    "tbz x8, #1, 12f\n"
+    "ld1 { v29.h }[2], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[6], [x24]\n"
     "b 15f\n"
     "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[4], [x23]\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[4], [x24]\n"
     "b 15f\n"
     "13:"  // Oddments: Load (1, 3): Bit 2: Unset
-    "tbz x4, #1, 14f\n"
-    "ld1 { v29.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[2], [x23]\n"
+    "tbz x8, #1, 14f\n"
+    "ld1 { v29.h }[0], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[2], [x24]\n"
     "b 15f\n"
     "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v29.b }[0], [x23]\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[0], [x24]\n"
     "15:"  // Oddments: Load (1, 3): Bit 2: End
-    "ldr x22, [x8, #0x48]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v4.4h\n"
-    "smlal2 v23.4s, v29.8h, v4.8h\n"
-    "add x22, x22, x5\n"
-    "tbz x4, #2, 17f\n"
-    "ld1 { v28.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 16f\n"
-    "ld1 { v28.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[6], [x22]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 17f\n"
+    "ld1 { v28.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 16f\n"
+    "ld1 { v28.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[6], [x23]\n"
     "b 19f\n"
     "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[4], [x22]\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[4], [x23]\n"
     "b 19f\n"
     "17:"  // Oddments: Load (1, 4): Bit 2: Unset
-    "tbz x4, #1, 18f\n"
-    "ld1 { v28.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[2], [x22]\n"
+    "tbz x8, #1, 18f\n"
+    "ld1 { v28.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[2], [x23]\n"
     "b 19f\n"
     "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v28.b }[0], [x22]\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[0], [x23]\n"
     "19:"  // Oddments: Load (1, 4): Bit 2: End
-    "ldr x21, [x8, #0x50]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v5.4h\n"
-    "smlal2 v23.4s, v28.8h, v5.8h\n"
-    "add x21, x21, x5\n"
-    "tbz x4, #2, 21f\n"
+    "ldr x21, [x12, #0x50]\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 21f\n"
     "ld1 { v27.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 20f\n"
+    "tbz x8, #1, 20f\n"
     "ld1 { v27.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[6], [x21]\n"
     "b 23f\n"
     "20:"  // Oddments: Load (1, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[4], [x21]\n"
     "b 23f\n"
     "21:"  // Oddments: Load (1, 2): Bit 2: Unset
-    "tbz x4, #1, 22f\n"
+    "tbz x8, #1, 22f\n"
     "ld1 { v27.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[2], [x21]\n"
     "b 23f\n"
     "22:"  // Oddments: Load (1, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
+    "tbz x8, #0, 23f\n"
     "ld1 { v27.b }[0], [x21]\n"
     "23:"  // Oddments: Load (1, 2): Bit 2: End
-    "ldr x20, [x8, #0x58]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
+    "ldr x19, [x12, #0x58]\n"
     "smlal v15.4s, v27.4h, v5.4h\n"
     "smlal2 v10.4s, v27.8h, v5.8h\n"
-    "add x20, x20, x5\n"
-    "smlal v20.4s, v27.4h, v3.4h\n"
-    "smlal2 v23.4s, v27.8h, v3.8h\n"
-    "tbz x4, #2, 25f\n"
-    "ld1 { v26.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 24f\n"
-    "ld1 { v26.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[6], [x20]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 25f\n"
+    "ld1 { v26.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 24f\n"
+    "ld1 { v26.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[6], [x19]\n"
     "b 27f\n"
     "24:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[4], [x20]\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[4], [x19]\n"
     "b 27f\n"
     "25:"  // Oddments: Load (3, 0): Bit 2: Unset
-    "tbz x4, #1, 26f\n"
-    "ld1 { v26.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[2], [x20]\n"
+    "tbz x8, #1, 26f\n"
+    "ld1 { v26.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[2], [x19]\n"
     "b 27f\n"
     "26:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v26.b }[0], [x20]\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[0], [x19]\n"
     "27:"  // Oddments: Load (3, 0): Bit 2: End
-    "ldr x19, [x8, #0x60]\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v16.4s, v26.4h, v3.4h\n"
-    "smlal2 v22.4s, v26.8h, v3.8h\n"
-    "add x19, x19, x5\n"
-    "tbz x4, #2, 29f\n"
-    "ld1 { v25.s }[0], [x19], #0x4\n"
-    "tbz x4, #1, 28f\n"
-    "ld1 { v25.h }[2], [x19], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[6], [x19]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 29f\n"
+    "ld1 { v25.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 28f\n"
+    "ld1 { v25.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[6], [x20]\n"
     "b 31f\n"
     "28:"  // Oddments: Load (2, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[4], [x19]\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[4], [x20]\n"
     "b 31f\n"
     "29:"  // Oddments: Load (2, 0): Bit 2: Unset
-    "tbz x4, #1, 30f\n"
-    "ld1 { v25.h }[0], [x19], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[2], [x19]\n"
+    "tbz x8, #1, 30f\n"
+    "ld1 { v25.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[2], [x20]\n"
     "b 31f\n"
     "30:"  // Oddments: Load (2, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v25.b }[0], [x19]\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[0], [x20]\n"
     "31:"  // Oddments: Load (2, 0): Bit 2: End
-    "ldr x10, [x8, #0x68]\n"
     "usubl v25.8h, v25.8b, v12.8b\n"
+    "ldr x19, [x12, #0x68]\n"
     "smlal v15.4s, v25.4h, v6.4h\n"
     "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "add x10, x10, x5\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v22.4s, v25.8h, v0.8h\n"
-    "tbz x4, #2, 33f\n"
-    "ld1 { v29.s }[0], [x10], #0x4\n"
-    "tbz x4, #1, 32f\n"
-    "ld1 { v29.h }[2], [x10], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[6], [x10]\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 33f\n"
+    "ld1 { v29.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 32f\n"
+    "ld1 { v29.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[6], [x19]\n"
     "b 35f\n"
     "32:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[4], [x10]\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[4], [x19]\n"
     "b 35f\n"
     "33:"  // Oddments: Load (3, 1): Bit 2: Unset
-    "tbz x4, #1, 34f\n"
-    "ld1 { v29.h }[0], [x10], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[2], [x10]\n"
+    "tbz x8, #1, 34f\n"
+    "ld1 { v29.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[2], [x19]\n"
     "b 35f\n"
     "34:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v29.b }[0], [x10]\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[0], [x19]\n"
     "35:"  // Oddments: Load (3, 1): Bit 2: End
-    "ldr x9, [x8, #0x70]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v16.4s, v29.4h, v4.4h\n"
-    "smlal2 v22.4s, v29.8h, v4.8h\n"
-    "add x9, x9, x5\n"
-    "tbz x4, #2, 37f\n"
-    "ld1 { v24.s }[0], [x9], #0x4\n"
-    "tbz x4, #1, 36f\n"
-    "ld1 { v24.h }[2], [x9], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[6], [x9]\n"
+    "ldr x19, [x12, #0x70]\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 37f\n"
+    "ld1 { v24.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 36f\n"
+    "ld1 { v24.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[6], [x19]\n"
     "b 39f\n"
     "36:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[4], [x9]\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[4], [x19]\n"
     "b 39f\n"
     "37:"  // Oddments: Load (2, 1): Bit 2: Unset
-    "tbz x4, #1, 38f\n"
-    "ld1 { v24.h }[0], [x9], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[2], [x9]\n"
+    "tbz x8, #1, 38f\n"
+    "ld1 { v24.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[2], [x19]\n"
     "b 39f\n"
     "38:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v24.b }[0], [x9]\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[0], [x19]\n"
     "39:"  // Oddments: Load (2, 1): Bit 2: End
-    "ldr x28, [x8, #0x78]\n"
     "usubl v24.8h, v24.8b, v12.8b\n"
+    "ldr x22, [x12, #0x78]\n"
     "smlal v15.4s, v24.4h, v7.4h\n"
     "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "add x28, x28, x5\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v22.4s, v24.8h, v1.8h\n"
-    "tbz x4, #2, 41f\n"
-    "ld1 { v27.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 40f\n"
-    "ld1 { v27.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[6], [x28]\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 41f\n"
+    "ld1 { v27.s }[0], [x22], #0x4\n"
+    "tbz x8, #1, 40f\n"
+    "ld1 { v27.h }[2], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[6], [x22]\n"
     "b 43f\n"
     "40:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[4], [x28]\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[4], [x22]\n"
     "b 43f\n"
     "41:"  // Oddments: Load (3, 3): Bit 2: Unset
-    "tbz x4, #1, 42f\n"
-    "ld1 { v27.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[2], [x28]\n"
+    "tbz x8, #1, 42f\n"
+    "ld1 { v27.h }[0], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[2], [x22]\n"
     "b 43f\n"
     "42:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v27.b }[0], [x28]\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[0], [x22]\n"
     "43:"  // Oddments: Load (3, 3): Bit 2: End
-    "ldr x27, [x8, #0x80]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v17.4s, v27.4h, v4.4h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
     "smlal2 v18.4s, v27.8h, v4.8h\n"
-    "add x27, x27, x5\n"
-    "tbz x4, #2, 45f\n"
-    "ld1 { v28.s }[0], [x27], #0x4\n"
-    "tbz x4, #1, 44f\n"
-    "ld1 { v28.h }[2], [x27], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[6], [x27]\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 45f\n"
+    "ld1 { v28.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 44f\n"
+    "ld1 { v28.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[6], [x21]\n"
     "b 47f\n"
     "44:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[4], [x27]\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[4], [x21]\n"
     "b 47f\n"
     "45:"  // Oddments: Load (2, 3): Bit 2: Unset
-    "tbz x4, #1, 46f\n"
-    "ld1 { v28.h }[0], [x27], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[2], [x27]\n"
+    "tbz x8, #1, 46f\n"
+    "ld1 { v28.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[2], [x21]\n"
     "b 47f\n"
     "46:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v28.b }[0], [x27]\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[0], [x21]\n"
     "47:"  // Oddments: Load (2, 3): Bit 2: End
-    "ldr x26, [x8, #0x88]\n"
     "usubl v28.8h, v28.8b, v12.8b\n"
-    "smlal v20.4s, v28.4h, v7.4h\n"
-    "smlal2 v23.4s, v28.8h, v7.8h\n"
-    "add x26, x26, x5\n"
-    "smlal v17.4s, v28.4h, v1.4h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
     "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "tbz x4, #2, 49f\n"
-    "ld1 { v26.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 48f\n"
-    "ld1 { v26.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[6], [x26]\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 49f\n"
+    "ld1 { v26.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 48f\n"
+    "ld1 { v26.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[6], [x20]\n"
     "b 51f\n"
     "48:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[4], [x26]\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[4], [x20]\n"
     "b 51f\n"
     "49:"  // Oddments: Load (3, 4): Bit 2: Unset
-    "tbz x4, #1, 50f\n"
-    "ld1 { v26.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[2], [x26]\n"
+    "tbz x8, #1, 50f\n"
+    "ld1 { v26.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[2], [x20]\n"
     "b 51f\n"
     "50:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v26.b }[0], [x26]\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[0], [x20]\n"
     "51:"  // Oddments: Load (3, 4): Bit 2: End
-    "ldr x25, [x8, #0x90]\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v5.4h\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
     "smlal2 v18.4s, v26.8h, v5.8h\n"
-    "add x25, x25, x5\n"
-    "tbz x4, #2, 53f\n"
-    "ld1 { v25.s }[0], [x25], #0x4\n"
-    "tbz x4, #1, 52f\n"
-    "ld1 { v25.h }[2], [x25], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[6], [x25]\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 53f\n"
+    "ld1 { v25.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 52f\n"
+    "ld1 { v25.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[6], [x23]\n"
     "b 55f\n"
     "52:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[4], [x25]\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[4], [x23]\n"
     "b 55f\n"
     "53:"  // Oddments: Load (4, 0): Bit 2: Unset
-    "tbz x4, #1, 54f\n"
-    "ld1 { v25.h }[0], [x25], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[2], [x25]\n"
+    "tbz x8, #1, 54f\n"
+    "ld1 { v25.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[2], [x23]\n"
     "b 55f\n"
     "54:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v25.b }[0], [x25]\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[0], [x23]\n"
     "55:"  // Oddments: Load (4, 0): Bit 2: End
-    "ldr x24, [x8, #0x98]\n"
     "usubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v25.4h, v6.4h\n"
-    "smlal2 v22.4s, v25.8h, v6.8h\n"
-    "add x24, x24, x5\n"
-    "tbz x4, #2, 57f\n"
+    "ldr x24, [x12, #0x98]\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal2 v21.4s, v25.8h, v6.8h\n"
+    "add x24, x24, x15\n"
+    "tbz x8, #2, 57f\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
-    "tbz x4, #1, 56f\n"
+    "tbz x8, #1, 56f\n"
     "ld1 { v29.h }[2], [x24], #0x2\n"
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[6], [x24]\n"
     "b 59f\n"
     "56:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[4], [x24]\n"
     "b 59f\n"
     "57:"  // Oddments: Load (2, 4): Bit 2: Unset
-    "tbz x4, #1, 58f\n"
+    "tbz x8, #1, 58f\n"
     "ld1 { v29.h }[0], [x24], #0x2\n"
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[2], [x24]\n"
     "b 59f\n"
     "58:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[0], [x24]\n"
     "59:"  // Oddments: Load (2, 4): Bit 2: End
-    "ldr x23, [x8, #0xa0]\n"
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v20.4s, v29.4h, v8.4h\n"
-    "smlal2 v23.4s, v29.8h, v8.8h\n"
-    "add x23, x23, x5\n"
-    "smlal v17.4s, v29.4h, v2.4h\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
     "smlal2 v18.4s, v29.8h, v2.8h\n"
-    "tbz x4, #2, 61f\n"
-    "ld1 { v27.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 60f\n"
-    "ld1 { v27.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[6], [x23]\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 61f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 60f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 63f\n"
     "60:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[4], [x23]\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 63f\n"
     "61:"  // Oddments: Load (4, 1): Bit 2: Unset
-    "tbz x4, #1, 62f\n"
-    "ld1 { v27.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[2], [x23]\n"
+    "tbz x8, #1, 62f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 63f\n"
     "62:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v27.b }[0], [x23]\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "63:"  // Oddments: Load (4, 1): Bit 2: End
-    "ldr x22, [x8, #0xa8]\n"
     "usubl v27.8h, v27.8b, v12.8b\n"
-    "smlal v16.4s, v27.4h, v7.4h\n"
-    "smlal2 v22.4s, v27.8h, v7.8h\n"
-    "add x22, x22, x5\n"
-    "tbz x4, #2, 65f\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 65f\n"
     "ld1 { v24.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 64f\n"
+    "tbz x8, #1, 64f\n"
     "ld1 { v24.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[6], [x22]\n"
     "b 67f\n"
     "64:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[4], [x22]\n"
     "b 67f\n"
     "65:"  // Oddments: Load (3, 2): Bit 2: Unset
-    "tbz x4, #1, 66f\n"
+    "tbz x8, #1, 66f\n"
     "ld1 { v24.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[2], [x22]\n"
     "b 67f\n"
     "66:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[0], [x22]\n"
     "67:"  // Oddments: Load (3, 2): Bit 2: End
-    "ldr x21, [x8, #0xb0]\n"
     "usubl v24.8h, v24.8b, v12.8b\n"
-    "smlal v16.4s, v24.4h, v5.4h\n"
-    "smlal2 v22.4s, v24.8h, v5.8h\n"
-    "add x21, x21, x5\n"
-    "smlal v17.4s, v24.4h, v3.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
     "smlal2 v18.4s, v24.8h, v3.8h\n"
-    "tbz x4, #2, 69f\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 69f\n"
     "ld1 { v26.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 68f\n"
+    "tbz x8, #1, 68f\n"
     "ld1 { v26.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[6], [x21]\n"
     "b 71f\n"
     "68:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[4], [x21]\n"
     "b 71f\n"
     "69:"  // Oddments: Load (4, 3): Bit 2: Unset
-    "tbz x4, #1, 70f\n"
+    "tbz x8, #1, 70f\n"
     "ld1 { v26.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[2], [x21]\n"
     "b 71f\n"
     "70:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[0], [x21]\n"
     "71:"  // Oddments: Load (4, 3): Bit 2: End
-    "ldr x20, [x8, #0xb8]\n"
     "usubl v26.8h, v26.8b, v12.8b\n"
-    "smlal v17.4s, v26.4h, v7.4h\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
     "smlal2 v18.4s, v26.8h, v7.8h\n"
-    "add x20, x20, x5\n"
-    "tbz x4, #2, 73f\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 73f\n"
     "ld1 { v25.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 72f\n"
+    "tbz x8, #1, 72f\n"
     "ld1 { v25.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[6], [x20]\n"
     "b 75f\n"
     "72:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[4], [x20]\n"
     "b 75f\n"
     "73:"  // Oddments: Load (4, 2): Bit 2: Unset
-    "tbz x4, #1, 74f\n"
+    "tbz x8, #1, 74f\n"
     "ld1 { v25.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[2], [x20]\n"
     "b 75f\n"
     "74:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[0], [x20]\n"
     "75:"  // Oddments: Load (4, 2): Bit 2: End
-    "ldr x19, [x8, #0xc0]\n"
     "usubl v25.8h, v25.8b, v12.8b\n"
-    "smlal v16.4s, v25.4h, v8.4h\n"
-    "smlal2 v22.4s, v25.8h, v8.8h\n"
-    "add x19, x19, x5\n"
-    "smlal v17.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
     "smlal2 v18.4s, v25.8h, v6.8h\n"
-    "tbz x4, #2, 77f\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 77f\n"
     "ld1 { v29.s }[0], [x19], #0x4\n"
-    "tbz x4, #1, 76f\n"
+    "tbz x8, #1, 76f\n"
     "ld1 { v29.h }[2], [x19], #0x2\n"
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[6], [x19]\n"
     "b 79f\n"
     "76:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[4], [x19]\n"
     "b 79f\n"
     "77:"  // Oddments: Load (4, 4): Bit 2: Unset
-    "tbz x4, #1, 78f\n"
+    "tbz x8, #1, 78f\n"
     "ld1 { v29.h }[0], [x19], #0x2\n"
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[2], [x19]\n"
     "b 79f\n"
     "78:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[0], [x19]\n"
     "79:"  // Oddments: Load (4, 4): Bit 2: End
     "usubl v29.8h, v29.8b, v12.8b\n"
-    "smlal v17.4s, v29.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
     "smlal2 v18.4s, v29.8h, v8.8h\n"
-    "tbz x4, #2, 81f\n"
-    "ld1 { v21.4s }, [x17], #0x10\n"
-    "ld1 { v30.4s }, [x15], #0x10\n"
-    "tbz x4, #1, 80f\n"
-    "ld1 { v31.d }[0], [x17], #0x8\n"
-    "ld1 { v9.d }[0], [x15], #0x8\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v31.s }[2], [x17]\n"
-    "ld1 { v9.s }[2], [x15]\n"
+    "tbz x8, #2, 81f\n"
+    "ld1 { v19.4s }, [x13], #0x10\n"
+    "ld1 { v0.4s }, [x11], #0x10\n"
+    "tbz x8, #1, 80f\n"
+    "ld1 { v4.d }[0], [x13], #0x8\n"
+    "ld1 { v31.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v4.s }[2], [x13]\n"
+    "ld1 { v31.s }[2], [x11]\n"
     "b 83f\n"
     "80:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v31.s }[0], [x17]\n"
-    "ld1 { v9.s }[0], [x15]\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v4.s }[0], [x13]\n"
+    "ld1 { v31.s }[0], [x11]\n"
     "b 83f\n"
     "81:"  // Oddments: Load requant params: Bit 2: Unset
-    "tbz x4, #1, 82f\n"
-    "ld1 { v21.d }[0], [x17], #0x8\n"
-    "ld1 { v30.d }[0], [x15], #0x8\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v21.s }[2], [x17]\n"
-    "ld1 { v30.s }[2], [x15]\n"
+    "tbz x8, #1, 82f\n"
+    "ld1 { v19.d }[0], [x13], #0x8\n"
+    "ld1 { v0.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[2], [x13]\n"
+    "ld1 { v0.s }[2], [x11]\n"
     "b 83f\n"
     "82:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v21.s }[0], [x17]\n"
-    "ld1 { v30.s }[0], [x15]\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[0], [x13]\n"
+    "ld1 { v0.s }[0], [x11]\n"
     "83:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v15.4s, v15.4s, v21.4s\n"
-    "add x14, x14, x7\n"
-    "sqrdmulh v10.4s, v10.4s, v31.4s\n"
-    "add x13, x13, x7\n"
-    "sqrdmulh v20.4s, v20.4s, v21.4s\n"
-    "add x12, x12, x7\n"
-    "sqrdmulh v23.4s, v23.4s, v31.4s\n"
-    "add x11, x11, x7\n"
-    "sqrdmulh v16.4s, v16.4s, v21.4s\n"
-    "and v26.16b, v15.16b, v30.16b\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "and v8.16b, v10.16b, v9.16b\n"
-    "and v4.16b, v20.16b, v30.16b\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "and v2.16b, v23.16b, v9.16b\n"
-    "and v1.16b, v16.16b, v30.16b\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "and v19.16b, v10.16b, v31.16b\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "sqrdmulh v22.4s, v22.4s, v31.4s\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v26.4s\n"
-    "sqrdmulh v17.4s, v17.4s, v21.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "sqadd v10.4s, v10.4s, v8.4s\n"
-    "sqadd v20.4s, v20.4s, v4.4s\n"
-    "srshl v15.4s, v15.4s, v30.4s\n"
-    "sqadd v23.4s, v23.4s, v2.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "srshl v20.4s, v20.4s, v30.4s\n"
-    "add v15.4s, v15.4s, v11.4s\n"
-    "srshl v23.4s, v23.4s, v9.4s\n"
-    "add v10.4s, v10.4s, v11.4s\n"
-    "smin v15.4s, v15.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v11.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "smax v15.4s, v15.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v14.4s\n"
-    "smax v10.4s, v10.4s, v19.4s\n"
-    "add v23.4s, v23.4s, v11.4s\n"
-    "smax v20.4s, v20.4s, v19.4s\n"
-    "uzp1 v15.16b, v15.16b, v10.16b\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "sqadd v16.4s, v16.4s, v1.4s\n"
-    "smax v23.4s, v23.4s, v19.4s\n"
-    "and v24.16b, v22.16b, v9.16b\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v23.16b\n"
-    "srshl v16.4s, v16.4s, v30.4s\n"
-    "and v2.16b, v17.16b, v30.16b\n"
-    "sshr v2.4s, v2.4s, #0x1f\n"
-    "uzp1 v20.16b, v20.16b, v20.16b\n"
-    "add v16.4s, v16.4s, v11.4s\n"
-    "sqadd v22.4s, v22.4s, v24.4s\n"
-    "and v31.16b, v18.16b, v9.16b\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "smin v16.4s, v16.4s, v14.4s\n"
-    "srshl v22.4s, v22.4s, v9.4s\n"
-    "sqadd v17.4s, v17.4s, v2.4s\n"
-    "smax v16.4s, v16.4s, v19.4s\n"
-    "add v22.4s, v22.4s, v11.4s\n"
-    "srshl v17.4s, v17.4s, v30.4s\n"
-    "sqadd v18.4s, v18.4s, v31.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "add v17.4s, v17.4s, v11.4s\n"
-    "srshl v18.4s, v18.4s, v9.4s\n"
-    "smax v22.4s, v22.4s, v19.4s\n"
-    "smin v17.4s, v17.4s, v14.4s\n"
-    "uzp1 v16.16b, v16.16b, v22.16b\n"
-    "add v18.4s, v18.4s, v11.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "smax v17.4s, v17.4s, v19.4s\n"
-    "smin v18.4s, v18.4s, v14.4s\n"
-    "smax v18.4s, v18.4s, v19.4s\n"
-    "uzp1 v17.16b, v17.16b, v18.16b\n"
-    "uzp1 v17.16b, v17.16b, v17.16b\n"
-    "tbz x4, #2, 85f\n"
-    "st1 { v15.s }[0], [x14], #0x4\n"
-    "st1 { v20.s }[0], [x13], #0x4\n"
-    "st1 { v16.s }[0], [x12], #0x4\n"
-    "st1 { v17.s }[0], [x11], #0x4\n"
-    "tbz x4, #1, 84f\n"
-    "st1 { v15.h }[2], [x14], #0x2\n"
-    "st1 { v20.h }[2], [x13], #0x2\n"
-    "st1 { v16.h }[2], [x12], #0x2\n"
-    "st1 { v17.h }[2], [x11], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[6], [x14], #0x1\n"
-    "st1 { v20.b }[6], [x13], #0x1\n"
-    "st1 { v16.b }[6], [x12], #0x1\n"
-    "st1 { v17.b }[6], [x11], #0x1\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "tbz x8, #2, 85f\n"
+    "st1 { v15.s }[0], [x10], #0x4\n"
+    "st1 { v9.s }[0], [x9], #0x4\n"
+    "st1 { v22.s }[0], [x28], #0x4\n"
+    "st1 { v23.s }[0], [x27], #0x4\n"
+    "tbz x8, #1, 84f\n"
+    "st1 { v15.h }[2], [x10], #0x2\n"
+    "st1 { v9.h }[2], [x9], #0x2\n"
+    "st1 { v22.h }[2], [x28], #0x2\n"
+    "st1 { v23.h }[2], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[6], [x10], #0x1\n"
+    "st1 { v9.b }[6], [x9], #0x1\n"
+    "st1 { v22.b }[6], [x28], #0x1\n"
+    "st1 { v23.b }[6], [x27], #0x1\n"
     "b 87f\n"
     "84:"  // Oddments: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[4], [x14], #0x1\n"
-    "st1 { v20.b }[4], [x13], #0x1\n"
-    "st1 { v16.b }[4], [x12], #0x1\n"
-    "st1 { v17.b }[4], [x11], #0x1\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[4], [x10], #0x1\n"
+    "st1 { v9.b }[4], [x9], #0x1\n"
+    "st1 { v22.b }[4], [x28], #0x1\n"
+    "st1 { v23.b }[4], [x27], #0x1\n"
     "b 87f\n"
     "85:"  // Oddments: Bit 2: Unset
-    "tbz x4, #1, 86f\n"
-    "st1 { v15.h }[0], [x14], #0x2\n"
-    "st1 { v20.h }[0], [x13], #0x2\n"
-    "st1 { v16.h }[0], [x12], #0x2\n"
-    "st1 { v17.h }[0], [x11], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[2], [x14], #0x1\n"
-    "st1 { v20.b }[2], [x13], #0x1\n"
-    "st1 { v16.b }[2], [x12], #0x1\n"
-    "st1 { v17.b }[2], [x11], #0x1\n"
+    "tbz x8, #1, 86f\n"
+    "st1 { v15.h }[0], [x10], #0x2\n"
+    "st1 { v9.h }[0], [x9], #0x2\n"
+    "st1 { v22.h }[0], [x28], #0x2\n"
+    "st1 { v23.h }[0], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[2], [x10], #0x1\n"
+    "st1 { v9.b }[2], [x9], #0x1\n"
+    "st1 { v22.b }[2], [x28], #0x1\n"
+    "st1 { v23.b }[2], [x27], #0x1\n"
     "b 87f\n"
     "86:"  // Oddments: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "st1 { v15.b }[0], [x14], #0x1\n"
-    "st1 { v20.b }[0], [x13], #0x1\n"
-    "st1 { v16.b }[0], [x12], #0x1\n"
-    "st1 { v17.b }[0], [x11], #0x1\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[0], [x10], #0x1\n"
+    "st1 { v9.b }[0], [x9], #0x1\n"
+    "st1 { v22.b }[0], [x28], #0x1\n"
+    "st1 { v23.b }[0], [x27], #0x1\n"
     "87:"  // Oddments: Bit 2: End
-
     "88:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
-    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
   );
 }
 
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index 73de965..11e993c 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst
+class a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy descriptor for the 5x5, stride-1, 2x2-output u8q kernel; template args presumably (input, weight, output, bias/accumulator) types - confirm against the base class.
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const uint8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 5, 5, 1, 1) {}  // CPUInfo is unused; Parent args presumably (output rows, output cols, kernel rows, kernel cols, stride rows, stride cols) - confirm against DepthwiseDepthfirstStrategy.
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }  // Reports VLType::None, the same value the previous static vl_type member held.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_u8q_5x5_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_u8q_5x5_mla::get_packed_size;
-
-  kern_type kernel = a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
-
-  a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;  // Pointer to the kernel implementation declared just above this class.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel pointer.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth measured in vector lengths - confirm against the base class.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
index b42f29a..6934dff 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const uint8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const uint8_t *const *inptrs_raw,
-      const uint8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -111,2096 +111,2070 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
-    "ldr x4, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x10, #0x0\n"
-    "ldr x3, [%x[params], %[offsetof_Params_weights]]\n"
-    "mov x1, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x25, %x[params], %[offsetof_Params_inptrs]\n"
-    "ldr x2, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x19, x4, #0x3\n"
-    "ldr x5, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x13, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v7.16b }, [x13]\n"
-    "add x8, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v13.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v19.4s }, [x8]\n"
-    "add x8, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v16.4s }, [x20]\n"
-    "ld1r { v12.4s }, [x8]\n"
-    "ldp x17, x16, [x21, #0x0]\n"
-    "ldp x6, x8, [x21, #0x10]\n"
-    "cbz x19, 3f\n"
-    "subs x19, x19, #0x1\n"
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q15, [x12, #0x0]\n"
-    "mov v18.16b, v15.16b\n"
-    "ldr q20, [x12, #0x10]\n"
-    "add x12, x12, #0x20\n"
-    "mov v11.16b, v15.16b\n"
-    "str x12, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr x10, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x0, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x17, x10, %[offsetof_Requantize32_a_offset]\n"
+    "add x9, x10, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x25, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x4, x10, %[offsetof_Requantize32_c_offset]\n"
+    "add x14, x10, %[offsetof_Requantize32_minval]\n"
+    "ldr x23, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x5, x10, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v9.16b }, [x17]\n"
+    "ld1r { v14.16b }, [x9]\n"
+    "lsr x3, x0, #0x3\n"
+    "ld1r { v18.8h }, [x4]\n"
+    "ld1r { v11.8h }, [x14]\n"
+    "mov x24, #0x0\n"
+    "mov x22, #0x0\n"
+    "ld1r { v13.8h }, [x5]\n"
+    "ldr x10, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "add x20, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x1, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x16, x8, [x25, #0x0]\n"
+    "ldp x4, x7, [x25, #0x10]\n"
+    "cbz x3, 3f\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "subs x3, x3, #0x1\n"
+    "mov v17.16b, v15.16b\n"
+    "ldr q16, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "ldr d2, [x23, #0x10]\n"
+    "mov v8.16b, v16.16b\n"
     "mov v10.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "mov v5.16b, v20.16b\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v8.16b, v20.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "mov v9.16b, v20.16b\n"
-    "ldr d3, [x3, #0x18]\n"
-    "ldr d4, [x3, #0x20]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "ldr d31, [x28, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "ldr d30, [x27, x10]\n"
-    "ldr d29, [x26, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "ldr d28, [x13, x10]\n"
-    "ldr d27, [x24, x10]\n"
-    "usubl v29.8h, v29.8b, v7.8b\n"
-    "ldr d23, [x23, x10]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "ldr d25, [x22, x10]\n"
-    "ldr d24, [x21, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "ldr d26, [x20, x10]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "ldr d22, [x0, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "usubl v22.8h, v22.8b, v7.8b\n"
+    "ldr d3, [x23, #0x18]\n"
+    "ldr d4, [x23, #0x20]\n"
+    "mov v7.16b, v16.16b\n"
+    "mov v6.16b, v15.16b\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "mov v5.16b, v16.16b\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "ldp x27, x21, [x20, #0x30]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "ldr d31, [x28, x24]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "ldr d30, [x6, x24]\n"
+    "ldr d29, [x26, x24]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "ldr d28, [x25, x24]\n"
+    "ldr d27, [x5, x24]\n"
+    "usubl v29.8h, v29.8b, v9.8b\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "ldr d23, [x2, x24]\n"
+    "ldr d25, [x27, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "ldr d24, [x21, x24]\n"
+    "ldr d26, [x12, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "ldr d22, [x19, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
     "beq 2f\n"
     "1:"  // Loop
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "ldr x20, [x25, #0x50]\n"
-    "subs x19, x19, #0x1\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "ldr d31, [x20, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v11.4s, v29.4h, v0.4h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal2 v8.4s, v29.8h, v0.8h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "smlal v10.4s, v28.4h, v0.4h\n"
-    "ldr x23, [x25, #0x78]\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x20, #0x50]\n"
+    "ldr d31, [x19, x24]\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "ldr x15, [x20, #0x58]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x20, #0x60]\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v1.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "smlal2 v5.4s, v27.8h, v1.8h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
-    "ldr x21, [x25, #0x98]\n"
-    "smlal2 v8.4s, v28.8h, v1.8h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "smlal v10.4s, v23.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "smlal2 v9.4s, v23.8h, v1.8h\n"
-    "ldr d1, [x3, #0x30]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x0, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "smlal v11.4s, v23.4h, v2.4h\n"
-    "ldr x9, [x25, #0xc8]\n"
-    "smlal2 v8.4s, v23.8h, v2.8h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "smlal v10.4s, v31.4h, v2.4h\n"
-    "ldr x28, [x25, #0xd8]\n"
-    "smlal2 v9.4s, v31.8h, v2.8h\n"
-    "ldr d2, [x3, #0x38]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr q6, [x2, #0x0]\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x7, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "ldr q21, [x5, #0x0]\n"
-    "smlal v11.4s, v31.4h, v3.4h\n"
-    "ldr q17, [x2, #0x10]\n"
-    "add x2, x2, #0x20\n"
-    "smlal2 v8.4s, v31.8h, v3.8h\n"
-    "ldr q14, [x5, #0x10]\n"
-    "add x5, x5, #0x20\n"
-    "smlal v10.4s, v30.4h, v3.4h\n"
-    "smlal2 v9.4s, v30.8h, v3.8h\n"
-    "ldr d3, [x3, #0x40]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x26, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
-    "smlal2 v5.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v11.4s, v30.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "smlal2 v8.4s, v30.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0x48]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v20.4s, v29.8h, v0.8h\n"
-    "smlal v18.4s, v28.4h, v0.4h\n"
+    "ldr x5, [x20, #0x70]\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
     "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "smlal v11.4s, v22.4h, v0.4h\n"
-    "smlal2 v8.4s, v22.8h, v0.8h\n"
-    "smlal v10.4s, v25.4h, v0.4h\n"
-    "smlal2 v9.4s, v25.8h, v0.8h\n"
-    "ldr d0, [x3, #0x50]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x10]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v1.4h\n"
-    "ldr x23, [x25, #0xf8]\n"
+    "ldr d30, [x15, x24]\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x23, #0x28]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "ldr x12, [x20, #0x80]\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
+    "smlal v15.4s, v27.4h, v2.4h\n"
+    "ldr x14, [x20, #0x90]\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
     "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "smlal v11.4s, v25.4h, v1.4h\n"
-    "smlal2 v8.4s, v25.8h, v1.8h\n"
-    "smlal v10.4s, v24.4h, v1.4h\n"
-    "smlal2 v9.4s, v24.8h, v1.8h\n"
-    "ldr d1, [x3, #0x58]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v20.4s, v23.8h, v2.8h\n"
-    "ldr d23, [x20, x10]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v31.4h, v2.4h\n"
-    "ldr x22, [x25, #0x100]\n"
-    "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v24.4h, v2.4h\n"
-    "smlal2 v8.4s, v24.8h, v2.8h\n"
-    "smlal v10.4s, v27.4h, v2.4h\n"
-    "smlal2 v9.4s, v27.8h, v2.8h\n"
-    "ldr d2, [x3, #0x60]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v20.4s, v31.8h, v3.8h\n"
-    "ldr d31, [x13, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v3.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "smlal v11.4s, v27.4h, v3.4h\n"
-    "smlal2 v8.4s, v27.8h, v3.8h\n"
-    "smlal v10.4s, v23.4h, v3.4h\n"
-    "smlal2 v9.4s, v23.8h, v3.8h\n"
-    "ldr d3, [x3, #0x68]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d26, [x14, x10]\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v11.4s, v23.4h, v4.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "smlal2 v8.4s, v23.8h, v4.8h\n"
-    "smlal v10.4s, v28.4h, v4.4h\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "ldr d4, [x3, #0x70]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v20.4s, v22.8h, v0.8h\n"
-    "ldr d22, [x0, x10]\n"
-    "usubl v22.8h, v22.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "smlal v11.4s, v31.4h, v0.4h\n"
-    "smlal2 v8.4s, v31.8h, v0.8h\n"
-    "smlal v10.4s, v30.4h, v0.4h\n"
-    "smlal2 v9.4s, v30.8h, v0.8h\n"
-    "ldr d0, [x3, #0x78]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v20.4s, v25.8h, v1.8h\n"
-    "ldr d25, [x11, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "smlal v11.4s, v30.4h, v1.4h\n"
-    "smlal2 v8.4s, v30.8h, v1.8h\n"
-    "smlal v10.4s, v26.4h, v1.4h\n"
-    "smlal2 v9.4s, v26.8h, v1.8h\n"
-    "ldr d1, [x3, #0x80]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v20.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x24, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "smlal v11.4s, v26.4h, v2.4h\n"
-    "smlal2 v8.4s, v26.8h, v2.8h\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "ldr d2, [x3, #0x88]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x15, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "ldr d3, [x3, #0x90]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v20.4s, v23.8h, v4.8h\n"
-    "ldr d23, [x9, x10]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "ldr d28, [x12, x10]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v22.4h, v4.4h\n"
-    "smlal2 v9.4s, v22.8h, v4.8h\n"
-    "ldr d4, [x3, #0x98]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x27, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "smlal v11.4s, v27.4h, v0.4h\n"
-    "smlal2 v8.4s, v27.8h, v0.8h\n"
-    "smlal v10.4s, v23.4h, v0.4h\n"
-    "smlal2 v9.4s, v23.8h, v0.8h\n"
-    "ldr d0, [x3, #0xa0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "smlal v11.4s, v23.4h, v1.4h\n"
-    "smlal2 v8.4s, v23.8h, v1.8h\n"
-    "smlal v10.4s, v31.4h, v1.4h\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "ldr d1, [x3, #0xa8]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v20.4s, v26.8h, v2.8h\n"
-    "ldr d26, [x7, x10]\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "smlal2 v8.4s, v31.8h, v2.8h\n"
-    "smlal v10.4s, v30.4h, v2.4h\n"
-    "smlal2 v9.4s, v30.8h, v2.8h\n"
-    "ldr d2, [x3, #0xb0]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x23, #0x30]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x26, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v30.4h, v3.4h\n"
-    "smlal2 v8.4s, v30.8h, v3.8h\n"
-    "smlal v10.4s, v28.4h, v3.4h\n"
-    "smlal2 v9.4s, v28.8h, v3.8h\n"
-    "ldr d3, [x3, #0xb8]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v31.8h, v2.8h\n"
+    "ldr d25, [x27, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x23, #0x38]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x23, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "smlal v11.4s, v28.4h, v4.4h\n"
-    "smlal2 v8.4s, v28.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0xc0]\n"
-    "add x3, x3, #0xc8\n"
-    "smlal v15.4s, v27.4h, v0.4h\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v20.4s, v27.8h, v0.8h\n"
-    "ldr d27, [x22, x10]\n"
-    "smlal v18.4s, v23.4h, v0.4h\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v8.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x20, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v0.4h\n"
-    "smlal2 v9.4s, v24.8h, v0.8h\n"
-    "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v20.4s, v23.8h, v1.8h\n"
-    "smlal v18.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v8.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x13, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v1.4h\n"
-    "smlal2 v9.4s, v27.8h, v1.8h\n"
-    "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v20.4s, v31.8h, v2.8h\n"
-    "smlal v18.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "smlal v11.4s, v27.4h, v2.4h\n"
-    "smlal2 v8.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x10]\n"
-    "add x10, x10, #0x8\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v20.4s, v30.8h, v3.8h\n"
-    "smlal v18.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v20.4s, v28.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v30.8h, v3.8h\n"
+    "ldr d24, [x5, x24]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x23, #0x40]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x11, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
+    "smlal v15.4s, v29.4h, v0.4h\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
     "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v27.4h, v4.4h\n"
-    "smlal2 v9.4s, v27.8h, v4.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v17.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v6.4s\n"
-    "sqrdmulh v5.4s, v5.4s, v17.4s\n"
-    "and v1.16b, v15.16b, v21.16b\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "and v29.16b, v20.16b, v14.16b\n"
-    "and v3.16b, v18.16b, v21.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "and v2.16b, v5.16b, v14.16b\n"
-    "sqrdmulh v11.4s, v11.4s, v6.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v8.4s, v8.4s, v17.4s\n"
+    "ldr d4, [x23, #0x48]\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "ldr q12, [x10, #0x0]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "ldr q19, [x1, #0x0]\n"
+    "ldr q20, [x10, #0x10]\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "smlal v15.4s, v28.4h, v1.4h\n"
+    "ldr q29, [x1, #0x10]\n"
+    "subs x3, x3, #0x1\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x26, x24]\n"
+    "ldr d0, [x23, #0x50]\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "add x10, x10, #0x20\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "smlal v15.4s, v23.4h, v2.4h\n"
+    "add x1, x1, #0x20\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x12, x24]\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x23, #0x58]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "smlal v15.4s, v31.4h, v3.4h\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x14, x24]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x23, #0x60]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "smlal v15.4s, v30.4h, v4.4h\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x15, x24]\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x23, #0x68]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x21, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "smlal v15.4s, v22.4h, v0.4h\n"
+    "ldr x14, [x20, #0x110]\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x23, #0x70]\n"
+    "ldr d22, [x9, x24]\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v25.4h, v1.4h\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x2, x24]\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x23, #0x78]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x13, x24]\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x23, #0x80]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "smlal v15.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x23, #0x88]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal v15.4s, v23.4h, v4.4h\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x28, x24]\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x23, #0x90]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x11, x24]\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal v15.4s, v31.4h, v0.4h\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x6, x24]\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x23, #0x98]\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x24]\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v2.4h\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x17, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v3.4h\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x5, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "smlal v15.4s, v24.4h, v4.4h\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x25, x24]\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "smlal v15.4s, v27.4h, v0.4h\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x26, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "add x23, x23, #0xc8\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x12, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
+    "smlal2 v5.4s, v24.8h, v0.8h\n"
+    "smlal v15.4s, v23.4h, v1.4h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x14, x24]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
+    "smlal2 v5.4s, v27.8h, v1.8h\n"
+    "smlal v15.4s, v31.4h, v2.4h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x24]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x24, x24, #0x8\n"
+    "smlal v15.4s, v30.4h, v3.4h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "smlal v15.4s, v28.4h, v4.4h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
     "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v1.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v6.4s\n"
-    "and v0.16b, v11.16b, v21.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v21.4s\n"
-    "sqadd v20.4s, v20.4s, v29.4s\n"
-    "sqadd v18.4s, v18.4s, v3.4s\n"
-    "sqadd v5.4s, v5.4s, v2.4s\n"
-    "and v27.16b, v8.16b, v14.16b\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v21.4s\n"
-    "srshl v5.4s, v5.4s, v14.4s\n"
-    "smin v15.4s, v15.4s, v12.4s\n"
-    "add v20.4s, v20.4s, v19.4s\n"
-    "add v18.4s, v18.4s, v19.4s\n"
-    "smax v15.4s, v15.4s, v16.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smin v18.4s, v18.4s, v12.4s\n"
-    "add v5.4s, v5.4s, v19.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "smin v5.4s, v5.4s, v12.4s\n"
-    "uzp1 v15.16b, v15.16b, v20.16b\n"
-    "sqadd v11.4s, v11.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x17, x1]\n"
-    "smax v5.4s, v5.4s, v16.4s\n"
-    "sqadd v8.4s, v8.4s, v27.4s\n"
-    "srshl v11.4s, v11.4s, v21.4s\n"
-    "and v30.16b, v10.16b, v21.16b\n"
-    "sshr v30.4s, v30.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v5.16b\n"
-    "add v11.4s, v11.4s, v19.4s\n"
-    "srshl v8.4s, v8.4s, v14.4s\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "str d18, [x16, x1]\n"
-    "smin v11.4s, v11.4s, v12.4s\n"
-    "sqrdmulh v9.4s, v9.4s, v17.4s\n"
-    "add v8.4s, v8.4s, v19.4s\n"
-    "sqadd v10.4s, v10.4s, v30.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "smin v8.4s, v8.4s, v12.4s\n"
-    "and v6.16b, v9.16b, v14.16b\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "smax v8.4s, v8.4s, v16.4s\n"
-    "srshl v10.4s, v10.4s, v21.4s\n"
-    "uzp1 v11.16b, v11.16b, v8.16b\n"
-    "add v10.4s, v10.4s, v19.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x6, x1]\n"
-    "smin v10.4s, v10.4s, v12.4s\n"
-    "sqadd v9.4s, v9.4s, v6.4s\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "srshl v9.4s, v9.4s, v14.4s\n"
-    "add v9.4s, v9.4s, v19.4s\n"
-    "smin v9.4s, v9.4s, v12.4s\n"
-    "smax v9.4s, v9.4s, v16.4s\n"
-    "uzp1 v10.16b, v10.16b, v9.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
+    "str d15, [x16, x22]\n"
     "uzp1 v10.16b, v10.16b, v10.16b\n"
-    "str d10, [x8, x1]\n"
-    "add x1, x1, #0x8\n"
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q15, [x12, #0x0]\n"
-    "mov v18.16b, v15.16b\n"
-    "ldr q20, [x12, #0x10]\n"
-    "add x12, x12, #0x20\n"
-    "mov v11.16b, v15.16b\n"
-    "str x12, [%x[params], %[offsetof_Params_bias]]\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "str d17, [x8, x22]\n"
+    "str d10, [x4, x22]\n"
+    "str d6, [x7, x22]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "add x22, x22, #0x8\n"
+    "ldr q16, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "ldr d2, [x23, #0x10]\n"
+    "mov v17.16b, v15.16b\n"
+    "mov v8.16b, v16.16b\n"
+    "ldr d3, [x23, #0x18]\n"
+    "ldr d4, [x23, #0x20]\n"
     "mov v10.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "mov v5.16b, v20.16b\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v8.16b, v20.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "mov v9.16b, v20.16b\n"
-    "ldr d3, [x3, #0x18]\n"
-    "ldr d4, [x3, #0x20]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "ldr d31, [x28, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "ldr d30, [x27, x10]\n"
-    "ldr d29, [x26, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "ldr d28, [x13, x10]\n"
-    "ldr d27, [x24, x10]\n"
-    "usubl v29.8h, v29.8b, v7.8b\n"
-    "ldr d23, [x23, x10]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "ldr d25, [x22, x10]\n"
-    "ldr d24, [x21, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "ldr d26, [x20, x10]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "ldr d22, [x0, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "usubl v22.8h, v22.8b, v7.8b\n"
+    "mov v7.16b, v16.16b\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "mov v6.16b, v15.16b\n"
+    "mov v5.16b, v16.16b\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "ldp x27, x21, [x20, #0x30]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "ldr d31, [x28, x24]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "ldr d30, [x6, x24]\n"
+    "ldr d29, [x26, x24]\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "ldr d28, [x25, x24]\n"
+    "ldr d27, [x5, x24]\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "usubl v29.8h, v29.8b, v9.8b\n"
+    "ldr d23, [x2, x24]\n"
+    "ldr d25, [x27, x24]\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "ldr d24, [x21, x24]\n"
+    "ldr d26, [x12, x24]\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "ldr d22, [x19, x24]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
     "bgt 1b\n"
     "2:"  // Tail
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "ldr x20, [x25, #0x50]\n"
-    "tst x4, #0x7\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "ldr d31, [x20, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v11.4s, v29.4h, v0.4h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal2 v8.4s, v29.8h, v0.8h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "smlal v10.4s, v28.4h, v0.4h\n"
-    "ldr x23, [x25, #0x78]\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x20, #0x50]\n"
+    "ldr d31, [x19, x24]\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "ldr x15, [x20, #0x58]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x20, #0x60]\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v1.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "smlal2 v5.4s, v27.8h, v1.8h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
-    "ldr x21, [x25, #0x98]\n"
-    "smlal2 v8.4s, v28.8h, v1.8h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "smlal v10.4s, v23.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "smlal2 v9.4s, v23.8h, v1.8h\n"
-    "ldr d1, [x3, #0x30]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x0, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "smlal v11.4s, v23.4h, v2.4h\n"
-    "ldr x9, [x25, #0xc8]\n"
-    "smlal2 v8.4s, v23.8h, v2.8h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "smlal v10.4s, v31.4h, v2.4h\n"
-    "ldr x28, [x25, #0xd8]\n"
-    "smlal2 v9.4s, v31.8h, v2.8h\n"
-    "ldr d2, [x3, #0x38]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x7, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "ldr q6, [x2, #0x0]\n"
-    "smlal v11.4s, v31.4h, v3.4h\n"
-    "ldr q21, [x5, #0x0]\n"
-    "smlal2 v8.4s, v31.8h, v3.8h\n"
-    "ldr q17, [x2, #0x10]\n"
-    "add x2, x2, #0x20\n"
-    "smlal v10.4s, v30.4h, v3.4h\n"
-    "ldr q14, [x5, #0x10]\n"
-    "add x5, x5, #0x20\n"
-    "smlal2 v9.4s, v30.8h, v3.8h\n"
-    "ldr d3, [x3, #0x40]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x26, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "smlal2 v5.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v11.4s, v30.4h, v4.4h\n"
-    "ldr x23, [x25, #0xf8]\n"
-    "smlal2 v8.4s, v30.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0x48]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v20.4s, v29.8h, v0.8h\n"
-    "smlal v18.4s, v28.4h, v0.4h\n"
+    "ldr x5, [x20, #0x70]\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
     "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "smlal v11.4s, v22.4h, v0.4h\n"
-    "smlal2 v8.4s, v22.8h, v0.8h\n"
-    "smlal v10.4s, v25.4h, v0.4h\n"
-    "smlal2 v9.4s, v25.8h, v0.8h\n"
-    "ldr d0, [x3, #0x50]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x10]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v1.4h\n"
-    "ldr x22, [x25, #0x100]\n"
+    "ldr d30, [x15, x24]\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x23, #0x28]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "ldr x12, [x20, #0x80]\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
+    "smlal v15.4s, v27.4h, v2.4h\n"
+    "ldr x14, [x20, #0x90]\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
     "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "smlal v11.4s, v25.4h, v1.4h\n"
-    "smlal2 v8.4s, v25.8h, v1.8h\n"
-    "smlal v10.4s, v24.4h, v1.4h\n"
-    "smlal2 v9.4s, v24.8h, v1.8h\n"
-    "ldr d1, [x3, #0x58]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v20.4s, v23.8h, v2.8h\n"
-    "ldr d23, [x20, x10]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v31.4h, v2.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v24.4h, v2.4h\n"
-    "smlal2 v8.4s, v24.8h, v2.8h\n"
-    "smlal v10.4s, v27.4h, v2.4h\n"
-    "smlal2 v9.4s, v27.8h, v2.8h\n"
-    "ldr d2, [x3, #0x60]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v20.4s, v31.8h, v3.8h\n"
-    "ldr d31, [x13, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v3.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "smlal v11.4s, v27.4h, v3.4h\n"
-    "smlal2 v8.4s, v27.8h, v3.8h\n"
-    "smlal v10.4s, v23.4h, v3.4h\n"
-    "smlal2 v9.4s, v23.8h, v3.8h\n"
-    "ldr d3, [x3, #0x68]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d26, [x14, x10]\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v11.4s, v23.4h, v4.4h\n"
-    "smlal2 v8.4s, v23.8h, v4.8h\n"
-    "smlal v10.4s, v28.4h, v4.4h\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "ldr d4, [x3, #0x70]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v20.4s, v22.8h, v0.8h\n"
-    "ldr d22, [x0, x10]\n"
-    "usubl v22.8h, v22.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "smlal v11.4s, v31.4h, v0.4h\n"
-    "smlal2 v8.4s, v31.8h, v0.8h\n"
-    "smlal v10.4s, v30.4h, v0.4h\n"
-    "smlal2 v9.4s, v30.8h, v0.8h\n"
-    "ldr d0, [x3, #0x78]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v20.4s, v25.8h, v1.8h\n"
-    "ldr d25, [x11, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "smlal v11.4s, v30.4h, v1.4h\n"
-    "smlal2 v8.4s, v30.8h, v1.8h\n"
-    "smlal v10.4s, v26.4h, v1.4h\n"
-    "smlal2 v9.4s, v26.8h, v1.8h\n"
-    "ldr d1, [x3, #0x80]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v20.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x24, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "smlal v11.4s, v26.4h, v2.4h\n"
-    "smlal2 v8.4s, v26.8h, v2.8h\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "ldr d2, [x3, #0x88]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x15, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "ldr d3, [x3, #0x90]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v20.4s, v23.8h, v4.8h\n"
-    "ldr d23, [x9, x10]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v18.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "ldr d28, [x12, x10]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v22.4h, v4.4h\n"
-    "smlal2 v9.4s, v22.8h, v4.8h\n"
-    "ldr d4, [x3, #0x98]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x27, x10]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "smlal v11.4s, v27.4h, v0.4h\n"
-    "smlal2 v8.4s, v27.8h, v0.8h\n"
-    "smlal v10.4s, v23.4h, v0.4h\n"
-    "smlal2 v9.4s, v23.8h, v0.8h\n"
-    "ldr d0, [x3, #0xa0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v18.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "smlal v11.4s, v23.4h, v1.4h\n"
-    "smlal2 v8.4s, v23.8h, v1.8h\n"
-    "smlal v10.4s, v31.4h, v1.4h\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "ldr d1, [x3, #0xa8]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v20.4s, v26.8h, v2.8h\n"
-    "ldr d26, [x7, x10]\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "smlal2 v8.4s, v31.8h, v2.8h\n"
-    "smlal v10.4s, v30.4h, v2.4h\n"
-    "smlal2 v9.4s, v30.8h, v2.8h\n"
-    "ldr d2, [x3, #0xb0]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x23, #0x30]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x26, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v30.4h, v3.4h\n"
-    "smlal2 v8.4s, v30.8h, v3.8h\n"
-    "smlal v10.4s, v28.4h, v3.4h\n"
-    "smlal2 v9.4s, v28.8h, v3.8h\n"
-    "ldr d3, [x3, #0xb8]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v31.8h, v2.8h\n"
+    "ldr d25, [x27, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x23, #0x38]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x23, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v18.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "smlal v11.4s, v28.4h, v4.4h\n"
-    "smlal2 v8.4s, v28.8h, v4.8h\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0xc0]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal v15.4s, v27.4h, v0.4h\n"
-    "smlal2 v20.4s, v27.8h, v0.8h\n"
-    "ldr d27, [x22, x10]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v18.4s, v23.4h, v0.4h\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v8.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x20, x10]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v0.4h\n"
-    "smlal2 v9.4s, v24.8h, v0.8h\n"
-    "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v20.4s, v23.8h, v1.8h\n"
-    "smlal v18.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v8.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x13, x10]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v1.4h\n"
-    "smlal2 v9.4s, v27.8h, v1.8h\n"
-    "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v20.4s, v31.8h, v2.8h\n"
-    "smlal v18.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "smlal v11.4s, v27.4h, v2.4h\n"
-    "smlal2 v8.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x10]\n"
-    "add x10, x10, #0x8\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v20.4s, v30.8h, v3.8h\n"
-    "smlal v18.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v20.4s, v28.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v30.8h, v3.8h\n"
+    "ldr d24, [x5, x24]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x23, #0x40]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x11, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
+    "smlal v15.4s, v29.4h, v0.4h\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
     "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "smlal v10.4s, v27.4h, v4.4h\n"
-    "smlal2 v9.4s, v27.8h, v4.8h\n"
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "sqrdmulh v20.4s, v20.4s, v17.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v6.4s\n"
-    "sqrdmulh v5.4s, v5.4s, v17.4s\n"
-    "and v1.16b, v15.16b, v21.16b\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "and v29.16b, v20.16b, v14.16b\n"
-    "and v3.16b, v18.16b, v21.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "and v2.16b, v5.16b, v14.16b\n"
-    "sqrdmulh v11.4s, v11.4s, v6.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v8.4s, v8.4s, v17.4s\n"
+    "ldr d4, [x23, #0x48]\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "ldr q12, [x10, #0x0]\n"
+    "ldr q19, [x1, #0x0]\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "smlal v15.4s, v28.4h, v1.4h\n"
+    "ldr q20, [x10, #0x10]\n"
+    "ldr q29, [x1, #0x10]\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x26, x24]\n"
+    "ldr d0, [x23, #0x50]\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "tst x0, #0x7\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "smlal v15.4s, v23.4h, v2.4h\n"
+    "add x10, x10, #0x20\n"
+    "add x1, x1, #0x20\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x12, x24]\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x23, #0x58]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "smlal v15.4s, v31.4h, v3.4h\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x14, x24]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x23, #0x60]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "ldr x14, [x20, #0x110]\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "smlal v15.4s, v30.4h, v4.4h\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x15, x24]\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x23, #0x68]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x21, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "smlal v15.4s, v22.4h, v0.4h\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x23, #0x70]\n"
+    "ldr d22, [x9, x24]\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v25.4h, v1.4h\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x2, x24]\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x23, #0x78]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x13, x24]\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x23, #0x80]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "smlal v15.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x23, #0x88]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal v15.4s, v23.4h, v4.4h\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x28, x24]\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x23, #0x90]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x11, x24]\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal v15.4s, v31.4h, v0.4h\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x6, x24]\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x23, #0x98]\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x24]\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v2.4h\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x17, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v3.4h\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x5, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "smlal v15.4s, v24.4h, v4.4h\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x25, x24]\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "smlal v15.4s, v27.4h, v0.4h\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x26, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x12, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
+    "smlal2 v5.4s, v24.8h, v0.8h\n"
+    "smlal v15.4s, v23.4h, v1.4h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x14, x24]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
+    "smlal2 v5.4s, v27.8h, v1.8h\n"
+    "smlal v15.4s, v31.4h, v2.4h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x24]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x24, x24, #0x8\n"
+    "smlal v15.4s, v30.4h, v3.4h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "smlal v15.4s, v28.4h, v4.4h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
     "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v1.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v6.4s\n"
-    "and v0.16b, v11.16b, v21.16b\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "srshl v15.4s, v15.4s, v21.4s\n"
-    "sqadd v20.4s, v20.4s, v29.4s\n"
-    "sqadd v18.4s, v18.4s, v3.4s\n"
-    "sqadd v5.4s, v5.4s, v2.4s\n"
-    "and v27.16b, v8.16b, v14.16b\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v19.4s\n"
-    "srshl v20.4s, v20.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v21.4s\n"
-    "srshl v5.4s, v5.4s, v14.4s\n"
-    "smin v15.4s, v15.4s, v12.4s\n"
-    "add v20.4s, v20.4s, v19.4s\n"
-    "add v18.4s, v18.4s, v19.4s\n"
-    "smax v15.4s, v15.4s, v16.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smin v18.4s, v18.4s, v12.4s\n"
-    "add v5.4s, v5.4s, v19.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "smin v5.4s, v5.4s, v12.4s\n"
-    "uzp1 v15.16b, v15.16b, v20.16b\n"
-    "sqadd v11.4s, v11.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x17, x1]\n"
-    "smax v5.4s, v5.4s, v16.4s\n"
-    "sqadd v8.4s, v8.4s, v27.4s\n"
-    "srshl v11.4s, v11.4s, v21.4s\n"
-    "and v30.16b, v10.16b, v21.16b\n"
-    "sshr v30.4s, v30.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v5.16b\n"
-    "add v11.4s, v11.4s, v19.4s\n"
-    "srshl v8.4s, v8.4s, v14.4s\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "str d18, [x16, x1]\n"
-    "smin v11.4s, v11.4s, v12.4s\n"
-    "sqrdmulh v9.4s, v9.4s, v17.4s\n"
-    "add v8.4s, v8.4s, v19.4s\n"
-    "sqadd v10.4s, v10.4s, v30.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "smin v8.4s, v8.4s, v12.4s\n"
-    "and v6.16b, v9.16b, v14.16b\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "smax v8.4s, v8.4s, v16.4s\n"
-    "srshl v10.4s, v10.4s, v21.4s\n"
-    "uzp1 v11.16b, v11.16b, v8.16b\n"
-    "add v10.4s, v10.4s, v19.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x6, x1]\n"
-    "smin v10.4s, v10.4s, v12.4s\n"
-    "sqadd v9.4s, v9.4s, v6.4s\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "srshl v9.4s, v9.4s, v14.4s\n"
-    "add v9.4s, v9.4s, v19.4s\n"
-    "smin v9.4s, v9.4s, v12.4s\n"
-    "smax v9.4s, v9.4s, v16.4s\n"
-    "uzp1 v10.16b, v10.16b, v9.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
+    "str d15, [x16, x22]\n"
     "uzp1 v10.16b, v10.16b, v10.16b\n"
-    "str d10, [x8, x1]\n"
-    "add x1, x1, #0x8\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "str d17, [x8, x22]\n"
+    "str d10, [x4, x22]\n"
+    "str d6, [x7, x22]\n"
+    "add x22, x22, #0x8\n"
     "beq 124f\n"
-    "add x3, x3, #0xc8\n"
+    "add x23, x23, #0xc8\n"
     "3:"  // Oddments
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "tbz x4, #2, 5f\n"
-    "ld1 { v15.4s }, [x12], #0x10\n"
-    "tbz x4, #1, 4f\n"
-    "ld1 { v20.d }[0], [x12], #0x8\n"
-    "tbz x4, #0, 7f\n"
-    "ld1 { v20.s }[2], [x12]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "tbz x0, #2, 5f\n"
+    "ld1 { v15.4s }, [x19], #0x10\n"
+    "tbz x0, #1, 4f\n"
+    "ld1 { v16.d }[0], [x19], #0x8\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v16.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
-    "ld1 { v20.s }[0], [x12]\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v16.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
-    "tbz x4, #1, 6f\n"
-    "ld1 { v15.d }[0], [x12], #0x8\n"
-    "tbz x4, #0, 7f\n"
-    "ld1 { v15.s }[2], [x12]\n"
+    "tbz x0, #1, 6f\n"
+    "ld1 { v15.d }[0], [x19], #0x8\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v15.s }[2], [x19]\n"
     "b 7f\n"
     "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
-    "ld1 { v15.s }[0], [x12]\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v15.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v18.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "mov v5.16b, v20.16b\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v11.16b, v15.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "mov v8.16b, v20.16b\n"
-    "ldr d3, [x3, #0x18]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "mov v17.16b, v15.16b\n"
+    "mov v8.16b, v16.16b\n"
+    "ldr d2, [x23, #0x10]\n"
+    "ldr d3, [x23, #0x18]\n"
     "mov v10.16b, v15.16b\n"
-    "ldr d4, [x3, #0x20]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "mov v9.16b, v20.16b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "add x28, x28, x10\n"
-    "add x27, x27, x10\n"
-    "add x26, x26, x10\n"
-    "add x13, x13, x10\n"
-    "add x24, x24, x10\n"
-    "add x23, x23, x10\n"
-    "add x22, x22, x10\n"
-    "add x21, x21, x10\n"
-    "add x20, x20, x10\n"
-    "add x0, x0, x10\n"
-    "tbz x4, #2, 9f\n"
+    "mov v7.16b, v16.16b\n"
+    "ldr d4, [x23, #0x20]\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "mov v6.16b, v15.16b\n"
+    "mov v5.16b, v16.16b\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "ldp x27, x21, [x20, #0x30]\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "add x28, x28, x24\n"
+    "add x6, x6, x24\n"
+    "add x26, x26, x24\n"
+    "add x25, x25, x24\n"
+    "add x5, x5, x24\n"
+    "add x2, x2, x24\n"
+    "add x27, x27, x24\n"
+    "add x21, x21, x24\n"
+    "add x12, x12, x24\n"
+    "add x19, x19, x24\n"
+    "tbz x0, #2, 9f\n"
     "ld1 { v31.s }[0], [x28], #0x4\n"
-    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "ld1 { v30.s }[0], [x6], #0x4\n"
     "ld1 { v29.s }[0], [x26], #0x4\n"
-    "ld1 { v28.s }[0], [x13], #0x4\n"
-    "ld1 { v27.s }[0], [x24], #0x4\n"
-    "ld1 { v23.s }[0], [x23], #0x4\n"
-    "ld1 { v25.s }[0], [x22], #0x4\n"
+    "ld1 { v28.s }[0], [x25], #0x4\n"
+    "ld1 { v27.s }[0], [x5], #0x4\n"
+    "ld1 { v23.s }[0], [x2], #0x4\n"
+    "ld1 { v25.s }[0], [x27], #0x4\n"
     "ld1 { v24.s }[0], [x21], #0x4\n"
-    "ld1 { v26.s }[0], [x20], #0x4\n"
-    "ld1 { v22.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 8f\n"
+    "ld1 { v26.s }[0], [x12], #0x4\n"
+    "ld1 { v22.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 8f\n"
     "ld1 { v31.h }[2], [x28], #0x2\n"
-    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "ld1 { v30.h }[2], [x6], #0x2\n"
     "ld1 { v29.h }[2], [x26], #0x2\n"
-    "ld1 { v28.h }[2], [x13], #0x2\n"
-    "ld1 { v27.h }[2], [x24], #0x2\n"
-    "ld1 { v23.h }[2], [x23], #0x2\n"
-    "ld1 { v25.h }[2], [x22], #0x2\n"
+    "ld1 { v28.h }[2], [x25], #0x2\n"
+    "ld1 { v27.h }[2], [x5], #0x2\n"
+    "ld1 { v23.h }[2], [x2], #0x2\n"
+    "ld1 { v25.h }[2], [x27], #0x2\n"
     "ld1 { v24.h }[2], [x21], #0x2\n"
-    "ld1 { v26.h }[2], [x20], #0x2\n"
-    "ld1 { v22.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "ld1 { v26.h }[2], [x12], #0x2\n"
+    "ld1 { v22.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[6], [x28]\n"
-    "ld1 { v30.b }[6], [x27]\n"
+    "ld1 { v30.b }[6], [x6]\n"
     "ld1 { v29.b }[6], [x26]\n"
-    "ld1 { v28.b }[6], [x13]\n"
-    "ld1 { v27.b }[6], [x24]\n"
-    "ld1 { v23.b }[6], [x23]\n"
-    "ld1 { v25.b }[6], [x22]\n"
+    "ld1 { v28.b }[6], [x25]\n"
+    "ld1 { v27.b }[6], [x5]\n"
+    "ld1 { v23.b }[6], [x2]\n"
+    "ld1 { v25.b }[6], [x27]\n"
     "ld1 { v24.b }[6], [x21]\n"
-    "ld1 { v26.b }[6], [x20]\n"
-    "ld1 { v22.b }[6], [x0]\n"
+    "ld1 { v26.b }[6], [x12]\n"
+    "ld1 { v22.b }[6], [x19]\n"
     "b 11f\n"
     "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[4], [x28]\n"
-    "ld1 { v30.b }[4], [x27]\n"
+    "ld1 { v30.b }[4], [x6]\n"
     "ld1 { v29.b }[4], [x26]\n"
-    "ld1 { v28.b }[4], [x13]\n"
-    "ld1 { v27.b }[4], [x24]\n"
-    "ld1 { v23.b }[4], [x23]\n"
-    "ld1 { v25.b }[4], [x22]\n"
+    "ld1 { v28.b }[4], [x25]\n"
+    "ld1 { v27.b }[4], [x5]\n"
+    "ld1 { v23.b }[4], [x2]\n"
+    "ld1 { v25.b }[4], [x27]\n"
     "ld1 { v24.b }[4], [x21]\n"
-    "ld1 { v26.b }[4], [x20]\n"
-    "ld1 { v22.b }[4], [x0]\n"
+    "ld1 { v26.b }[4], [x12]\n"
+    "ld1 { v22.b }[4], [x19]\n"
     "b 11f\n"
     "9:"  // Oddments: Initial loads: Bit 2: Unset
-    "tbz x4, #1, 10f\n"
+    "tbz x0, #1, 10f\n"
     "ld1 { v31.h }[0], [x28], #0x2\n"
-    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "ld1 { v30.h }[0], [x6], #0x2\n"
     "ld1 { v29.h }[0], [x26], #0x2\n"
-    "ld1 { v28.h }[0], [x13], #0x2\n"
-    "ld1 { v27.h }[0], [x24], #0x2\n"
-    "ld1 { v23.h }[0], [x23], #0x2\n"
-    "ld1 { v25.h }[0], [x22], #0x2\n"
+    "ld1 { v28.h }[0], [x25], #0x2\n"
+    "ld1 { v27.h }[0], [x5], #0x2\n"
+    "ld1 { v23.h }[0], [x2], #0x2\n"
+    "ld1 { v25.h }[0], [x27], #0x2\n"
     "ld1 { v24.h }[0], [x21], #0x2\n"
-    "ld1 { v26.h }[0], [x20], #0x2\n"
-    "ld1 { v22.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "ld1 { v26.h }[0], [x12], #0x2\n"
+    "ld1 { v22.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[2], [x28]\n"
-    "ld1 { v30.b }[2], [x27]\n"
+    "ld1 { v30.b }[2], [x6]\n"
     "ld1 { v29.b }[2], [x26]\n"
-    "ld1 { v28.b }[2], [x13]\n"
-    "ld1 { v27.b }[2], [x24]\n"
-    "ld1 { v23.b }[2], [x23]\n"
-    "ld1 { v25.b }[2], [x22]\n"
+    "ld1 { v28.b }[2], [x25]\n"
+    "ld1 { v27.b }[2], [x5]\n"
+    "ld1 { v23.b }[2], [x2]\n"
+    "ld1 { v25.b }[2], [x27]\n"
     "ld1 { v24.b }[2], [x21]\n"
-    "ld1 { v26.b }[2], [x20]\n"
-    "ld1 { v22.b }[2], [x0]\n"
+    "ld1 { v26.b }[2], [x12]\n"
+    "ld1 { v22.b }[2], [x19]\n"
     "b 11f\n"
     "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[0], [x28]\n"
-    "ld1 { v30.b }[0], [x27]\n"
+    "ld1 { v30.b }[0], [x6]\n"
     "ld1 { v29.b }[0], [x26]\n"
-    "ld1 { v28.b }[0], [x13]\n"
-    "ld1 { v27.b }[0], [x24]\n"
-    "ld1 { v23.b }[0], [x23]\n"
-    "ld1 { v25.b }[0], [x22]\n"
+    "ld1 { v28.b }[0], [x25]\n"
+    "ld1 { v27.b }[0], [x5]\n"
+    "ld1 { v23.b }[0], [x2]\n"
+    "ld1 { v25.b }[0], [x27]\n"
     "ld1 { v24.b }[0], [x21]\n"
-    "ld1 { v26.b }[0], [x20]\n"
-    "ld1 { v22.b }[0], [x0]\n"
+    "ld1 { v26.b }[0], [x12]\n"
+    "ld1 { v22.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "ldr x20, [x25, #0x50]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "usubl v29.8h, v29.8b, v7.8b\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v11.4s, v29.4h, v0.4h\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal2 v8.4s, v29.8h, v0.8h\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v28.4h, v0.4h\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "usubl v22.8h, v22.8b, v7.8b\n"
+    "ldr x19, [x20, #0x50]\n"
+    "usubl v29.8h, v29.8b, v9.8b\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "add x19, x19, x24\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v5.4s, v28.8h, v0.8h\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "add x20, x20, x10\n"
-    "smlal v18.4s, v27.4h, v1.4h\n"
-    "smlal2 v5.4s, v27.8h, v1.8h\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
-    "smlal2 v8.4s, v28.8h, v1.8h\n"
-    "smlal v10.4s, v23.4h, v1.4h\n"
-    "smlal2 v9.4s, v23.8h, v1.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v5.4s, v23.8h, v1.8h\n"
     "smlal v15.4s, v27.4h, v2.4h\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v23.4h, v2.4h\n"
-    "smlal2 v8.4s, v23.8h, v2.8h\n"
-    "tbz x4, #2, 13f\n"
-    "ld1 { v31.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 12f\n"
-    "ld1 { v31.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[6], [x20]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
+    "tbz x0, #2, 13f\n"
+    "ld1 { v31.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 12f\n"
+    "ld1 { v31.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[6], [x19]\n"
     "b 15f\n"
     "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[4], [x20]\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[4], [x19]\n"
     "b 15f\n"
     "13:"  // Oddments: Load (1, 3): Bit 2: Unset
-    "tbz x4, #1, 14f\n"
-    "ld1 { v31.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[2], [x20]\n"
+    "tbz x0, #1, 14f\n"
+    "ld1 { v31.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[2], [x19]\n"
     "b 15f\n"
     "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[0], [x20]\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[0], [x19]\n"
     "15:"  // Oddments: Load (1, 3): Bit 2: End
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "ldr x15, [x20, #0x58]\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v5.4s, v31.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "add x28, x28, x10\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v10.4s, v31.4h, v2.4h\n"
-    "smlal2 v9.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v3.4h\n"
-    "smlal2 v8.4s, v31.8h, v3.8h\n"
-    "tbz x4, #2, 17f\n"
-    "ld1 { v30.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 16f\n"
-    "ld1 { v30.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[6], [x28]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "add x15, x15, x24\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
+    "tbz x0, #2, 17f\n"
+    "ld1 { v30.s }[0], [x15], #0x4\n"
+    "tbz x0, #1, 16f\n"
+    "ld1 { v30.h }[2], [x15], #0x2\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[6], [x15]\n"
     "b 19f\n"
     "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[4], [x28]\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[4], [x15]\n"
     "b 19f\n"
     "17:"  // Oddments: Load (1, 4): Bit 2: Unset
-    "tbz x4, #1, 18f\n"
-    "ld1 { v30.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[2], [x28]\n"
+    "tbz x0, #1, 18f\n"
+    "ld1 { v30.h }[0], [x15], #0x2\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[2], [x15]\n"
     "b 19f\n"
     "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[0], [x28]\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[0], [x15]\n"
     "19:"  // Oddments: Load (1, 4): Bit 2: End
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "ldr x19, [x20, #0x60]\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v5.4s, v30.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "add x0, x0, x10\n"
-    "smlal v10.4s, v30.4h, v3.4h\n"
-    "smlal2 v9.4s, v30.8h, v3.8h\n"
-    "tbz x4, #2, 21f\n"
-    "ld1 { v27.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 20f\n"
-    "ld1 { v27.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[6], [x0]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "add x19, x19, x24\n"
+    "tbz x0, #2, 21f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 20f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 23f\n"
     "20:"  // Oddments: Load (0, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[4], [x0]\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 23f\n"
     "21:"  // Oddments: Load (0, 5): Bit 2: Unset
-    "tbz x4, #1, 22f\n"
-    "ld1 { v27.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[2], [x0]\n"
+    "tbz x0, #1, 22f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 23f\n"
     "22:"  // Oddments: Load (0, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[0], [x0]\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "23:"  // Oddments: Load (0, 5): Bit 2: End
-    "smlal v11.4s, v30.4h, v4.4h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal2 v8.4s, v30.8h, v4.8h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "add x7, x7, x10\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "ldr d0, [x23, #0x28]\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "add x27, x27, x24\n"
     "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v20.4s, v29.8h, v0.8h\n"
-    "smlal v18.4s, v28.4h, v0.4h\n"
-    "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "smlal v11.4s, v22.4h, v0.4h\n"
-    "smlal2 v8.4s, v22.8h, v0.8h\n"
-    "tbz x4, #2, 25f\n"
-    "ld1 { v25.s }[0], [x7], #0x4\n"
-    "tbz x4, #1, 24f\n"
-    "ld1 { v25.h }[2], [x7], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[6], [x7]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "tbz x0, #2, 25f\n"
+    "ld1 { v25.s }[0], [x27], #0x4\n"
+    "tbz x0, #1, 24f\n"
+    "ld1 { v25.h }[2], [x27], #0x2\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[6], [x27]\n"
     "b 27f\n"
     "24:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[4], [x7]\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[4], [x27]\n"
     "b 27f\n"
     "25:"  // Oddments: Load (2, 1): Bit 2: Unset
-    "tbz x4, #1, 26f\n"
-    "ld1 { v25.h }[0], [x7], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[2], [x7]\n"
+    "tbz x0, #1, 26f\n"
+    "ld1 { v25.h }[0], [x27], #0x2\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[2], [x27]\n"
     "b 27f\n"
     "26:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[0], [x7]\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[0], [x27]\n"
     "27:"  // Oddments: Load (2, 1): Bit 2: End
-    "ldr d1, [x3, #0x30]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v25.4h, v0.4h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v25.8h, v0.8h\n"
-    "add x26, x26, x10\n"
+    "ldr d1, [x23, #0x30]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x5, [x20, #0x70]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "add x5, x5, x24\n"
     "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "smlal v18.4s, v23.4h, v1.4h\n"
-    "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "smlal v11.4s, v25.4h, v1.4h\n"
-    "smlal2 v8.4s, v25.8h, v1.8h\n"
-    "tbz x4, #2, 29f\n"
-    "ld1 { v24.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 28f\n"
-    "ld1 { v24.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[6], [x26]\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "tbz x0, #2, 29f\n"
+    "ld1 { v24.s }[0], [x5], #0x4\n"
+    "tbz x0, #1, 28f\n"
+    "ld1 { v24.h }[2], [x5], #0x2\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[6], [x5]\n"
     "b 31f\n"
     "28:"  // Oddments: Load (2, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[4], [x26]\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[4], [x5]\n"
     "b 31f\n"
     "29:"  // Oddments: Load (2, 2): Bit 2: Unset
-    "tbz x4, #1, 30f\n"
-    "ld1 { v24.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[2], [x26]\n"
+    "tbz x0, #1, 30f\n"
+    "ld1 { v24.h }[0], [x5], #0x2\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[2], [x5]\n"
     "b 31f\n"
     "30:"  // Oddments: Load (2, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[0], [x26]\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[0], [x5]\n"
     "31:"  // Oddments: Load (2, 2): Bit 2: End
-    "ldr d2, [x3, #0x38]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v1.4h\n"
-    "ldr x23, [x25, #0x78]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v1.8h\n"
-    "add x23, x23, x10\n"
+    "ldr d2, [x23, #0x38]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "add x11, x11, x24\n"
     "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v20.4s, v23.8h, v2.8h\n"
-    "smlal v18.4s, v31.4h, v2.4h\n"
-    "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "smlal v11.4s, v24.4h, v2.4h\n"
-    "smlal2 v8.4s, v24.8h, v2.8h\n"
-    "tbz x4, #2, 33f\n"
-    "ld1 { v27.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 32f\n"
-    "ld1 { v27.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[6], [x23]\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "tbz x0, #2, 33f\n"
+    "ld1 { v27.s }[0], [x11], #0x4\n"
+    "tbz x0, #1, 32f\n"
+    "ld1 { v27.h }[2], [x11], #0x2\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[6], [x11]\n"
     "b 35f\n"
     "32:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[4], [x23]\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[4], [x11]\n"
     "b 35f\n"
     "33:"  // Oddments: Load (2, 3): Bit 2: Unset
-    "tbz x4, #1, 34f\n"
-    "ld1 { v27.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[2], [x23]\n"
+    "tbz x0, #1, 34f\n"
+    "ld1 { v27.h }[0], [x11], #0x2\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[2], [x11]\n"
     "b 35f\n"
     "34:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[0], [x23]\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[0], [x11]\n"
     "35:"  // Oddments: Load (2, 3): Bit 2: End
-    "ldr d3, [x3, #0x40]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v2.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v27.8h, v2.8h\n"
-    "add x20, x20, x10\n"
+    "ldr d3, [x23, #0x40]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x12, [x20, #0x80]\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "add x12, x12, x24\n"
     "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v20.4s, v31.8h, v3.8h\n"
-    "smlal v18.4s, v30.4h, v3.4h\n"
-    "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "smlal v11.4s, v27.4h, v3.4h\n"
-    "smlal2 v8.4s, v27.8h, v3.8h\n"
-    "tbz x4, #2, 37f\n"
-    "ld1 { v23.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 36f\n"
-    "ld1 { v23.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[6], [x20]\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "tbz x0, #2, 37f\n"
+    "ld1 { v23.s }[0], [x12], #0x4\n"
+    "tbz x0, #1, 36f\n"
+    "ld1 { v23.h }[2], [x12], #0x2\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[6], [x12]\n"
     "b 39f\n"
     "36:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[4], [x20]\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[4], [x12]\n"
     "b 39f\n"
     "37:"  // Oddments: Load (2, 4): Bit 2: Unset
-    "tbz x4, #1, 38f\n"
-    "ld1 { v23.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[2], [x20]\n"
+    "tbz x0, #1, 38f\n"
+    "ld1 { v23.h }[0], [x12], #0x2\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[2], [x12]\n"
     "b 39f\n"
     "38:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[0], [x20]\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[0], [x12]\n"
     "39:"  // Oddments: Load (2, 4): Bit 2: End
-    "ldr d4, [x3, #0x48]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v10.4s, v23.4h, v3.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v23.8h, v3.8h\n"
-    "add x22, x22, x10\n"
+    "ldr d4, [x23, #0x48]\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "add x26, x26, x24\n"
     "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v20.4s, v30.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v23.4h, v4.4h\n"
-    "smlal2 v8.4s, v23.8h, v4.8h\n"
-    "tbz x4, #2, 41f\n"
-    "ld1 { v28.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 40f\n"
-    "ld1 { v28.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[6], [x22]\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "tbz x0, #2, 41f\n"
+    "ld1 { v28.s }[0], [x26], #0x4\n"
+    "tbz x0, #1, 40f\n"
+    "ld1 { v28.h }[2], [x26], #0x2\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[6], [x26]\n"
     "b 43f\n"
     "40:"  // Oddments: Load (2, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[4], [x22]\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[4], [x26]\n"
     "b 43f\n"
     "41:"  // Oddments: Load (2, 5): Bit 2: Unset
-    "tbz x4, #1, 42f\n"
-    "ld1 { v28.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[2], [x22]\n"
+    "tbz x0, #1, 42f\n"
+    "ld1 { v28.h }[0], [x26], #0x2\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[2], [x26]\n"
     "b 43f\n"
     "42:"  // Oddments: Load (2, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[0], [x22]\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[0], [x26]\n"
     "43:"  // Oddments: Load (2, 5): Bit 2: End
-    "ldr d0, [x3, #0x50]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v10.4s, v28.4h, v4.4h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "add x13, x13, x10\n"
+    "ldr d0, [x23, #0x50]\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x14, [x20, #0x90]\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "add x14, x14, x24\n"
     "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v20.4s, v22.8h, v0.8h\n"
-    "smlal v18.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "tbz x4, #2, 45f\n"
-    "ld1 { v31.s }[0], [x13], #0x4\n"
-    "tbz x4, #1, 44f\n"
-    "ld1 { v31.h }[2], [x13], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[6], [x13]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "tbz x0, #2, 45f\n"
+    "ld1 { v31.s }[0], [x14], #0x4\n"
+    "tbz x0, #1, 44f\n"
+    "ld1 { v31.h }[2], [x14], #0x2\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[6], [x14]\n"
     "b 47f\n"
     "44:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[4], [x13]\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[4], [x14]\n"
     "b 47f\n"
     "45:"  // Oddments: Load (3, 0): Bit 2: Unset
-    "tbz x4, #1, 46f\n"
-    "ld1 { v31.h }[0], [x13], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[2], [x13]\n"
+    "tbz x0, #1, 46f\n"
+    "ld1 { v31.h }[0], [x14], #0x2\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[2], [x14]\n"
     "b 47f\n"
     "46:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[0], [x13]\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[0], [x14]\n"
     "47:"  // Oddments: Load (3, 0): Bit 2: End
-    "ldr x21, [x25, #0x98]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v11.4s, v31.4h, v0.4h\n"
-    "smlal2 v8.4s, v31.8h, v0.8h\n"
-    "add x21, x21, x10\n"
-    "tbz x4, #2, 49f\n"
-    "ld1 { v30.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 48f\n"
-    "ld1 { v30.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[6], [x21]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "add x15, x15, x24\n"
+    "tbz x0, #2, 49f\n"
+    "ld1 { v30.s }[0], [x15], #0x4\n"
+    "tbz x0, #1, 48f\n"
+    "ld1 { v30.h }[2], [x15], #0x2\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[6], [x15]\n"
     "b 51f\n"
     "48:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[4], [x21]\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[4], [x15]\n"
     "b 51f\n"
     "49:"  // Oddments: Load (3, 1): Bit 2: Unset
-    "tbz x4, #1, 50f\n"
-    "ld1 { v30.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[2], [x21]\n"
+    "tbz x0, #1, 50f\n"
+    "ld1 { v30.h }[0], [x15], #0x2\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[2], [x15]\n"
     "b 51f\n"
     "50:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[0], [x21]\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[0], [x15]\n"
     "51:"  // Oddments: Load (3, 1): Bit 2: End
-    "ldr d1, [x3, #0x58]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v10.4s, v30.4h, v0.4h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v30.8h, v0.8h\n"
-    "add x14, x14, x10\n"
+    "ldr d1, [x23, #0x58]\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "add x21, x21, x24\n"
     "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v20.4s, v25.8h, v1.8h\n"
-    "smlal v18.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "smlal v11.4s, v30.4h, v1.4h\n"
-    "smlal2 v8.4s, v30.8h, v1.8h\n"
-    "tbz x4, #2, 53f\n"
-    "ld1 { v26.s }[0], [x14], #0x4\n"
-    "tbz x4, #1, 52f\n"
-    "ld1 { v26.h }[2], [x14], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[6], [x14]\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "tbz x0, #2, 53f\n"
+    "ld1 { v26.s }[0], [x21], #0x4\n"
+    "tbz x0, #1, 52f\n"
+    "ld1 { v26.h }[2], [x21], #0x2\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[6], [x21]\n"
     "b 55f\n"
     "52:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[4], [x14]\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[4], [x21]\n"
     "b 55f\n"
     "53:"  // Oddments: Load (3, 2): Bit 2: Unset
-    "tbz x4, #1, 54f\n"
-    "ld1 { v26.h }[0], [x14], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[2], [x14]\n"
+    "tbz x0, #1, 54f\n"
+    "ld1 { v26.h }[0], [x21], #0x2\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[2], [x21]\n"
     "b 55f\n"
     "54:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[0], [x14]\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[0], [x21]\n"
     "55:"  // Oddments: Load (3, 2): Bit 2: End
-    "ldr d2, [x3, #0x60]\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v10.4s, v26.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v26.8h, v1.8h\n"
-    "add x11, x11, x10\n"
+    "ldr d2, [x23, #0x60]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "add x2, x2, x24\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v20.4s, v24.8h, v2.8h\n"
-    "smlal v18.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "smlal v11.4s, v26.4h, v2.4h\n"
-    "smlal2 v8.4s, v26.8h, v2.8h\n"
-    "tbz x4, #2, 57f\n"
-    "ld1 { v25.s }[0], [x11], #0x4\n"
-    "tbz x4, #1, 56f\n"
-    "ld1 { v25.h }[2], [x11], #0x2\n"
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[6], [x11]\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "tbz x0, #2, 57f\n"
+    "ld1 { v25.s }[0], [x2], #0x4\n"
+    "tbz x0, #1, 56f\n"
+    "ld1 { v25.h }[2], [x2], #0x2\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[6], [x2]\n"
     "b 59f\n"
     "56:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[4], [x11]\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[4], [x2]\n"
     "b 59f\n"
     "57:"  // Oddments: Load (3, 3): Bit 2: Unset
-    "tbz x4, #1, 58f\n"
-    "ld1 { v25.h }[0], [x11], #0x2\n"
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[2], [x11]\n"
+    "tbz x0, #1, 58f\n"
+    "ld1 { v25.h }[0], [x2], #0x2\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[2], [x2]\n"
     "b 59f\n"
     "58:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[0], [x11]\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[0], [x2]\n"
     "59:"  // Oddments: Load (3, 3): Bit 2: End
-    "ldr d3, [x3, #0x68]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "add x24, x24, x10\n"
+    "ldr d3, [x23, #0x68]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x13, x13, x24\n"
     "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "smlal v18.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "tbz x4, #2, 61f\n"
-    "ld1 { v24.s }[0], [x24], #0x4\n"
-    "tbz x4, #1, 60f\n"
-    "ld1 { v24.h }[2], [x24], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[6], [x24]\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "tbz x0, #2, 61f\n"
+    "ld1 { v24.s }[0], [x13], #0x4\n"
+    "tbz x0, #1, 60f\n"
+    "ld1 { v24.h }[2], [x13], #0x2\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[6], [x13]\n"
     "b 63f\n"
     "60:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[4], [x24]\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[4], [x13]\n"
     "b 63f\n"
     "61:"  // Oddments: Load (3, 4): Bit 2: Unset
-    "tbz x4, #1, 62f\n"
-    "ld1 { v24.h }[0], [x24], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[2], [x24]\n"
+    "tbz x0, #1, 62f\n"
+    "ld1 { v24.h }[0], [x13], #0x2\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[2], [x13]\n"
     "b 63f\n"
     "62:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[0], [x24]\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[0], [x13]\n"
     "63:"  // Oddments: Load (3, 4): Bit 2: End
-    "ldr d4, [x3, #0x70]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "add x0, x0, x10\n"
+    "ldr d4, [x23, #0x70]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "add x9, x9, x24\n"
     "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v20.4s, v23.8h, v4.8h\n"
-    "smlal v18.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "tbz x4, #2, 65f\n"
-    "ld1 { v22.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 64f\n"
-    "ld1 { v22.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[6], [x0]\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "tbz x0, #2, 65f\n"
+    "ld1 { v22.s }[0], [x9], #0x4\n"
+    "tbz x0, #1, 64f\n"
+    "ld1 { v22.h }[2], [x9], #0x2\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[6], [x9]\n"
     "b 67f\n"
     "64:"  // Oddments: Load (3, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[4], [x0]\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[4], [x9]\n"
     "b 67f\n"
     "65:"  // Oddments: Load (3, 5): Bit 2: Unset
-    "tbz x4, #1, 66f\n"
-    "ld1 { v22.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[2], [x0]\n"
+    "tbz x0, #1, 66f\n"
+    "ld1 { v22.h }[0], [x9], #0x2\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[2], [x9]\n"
     "b 67f\n"
     "66:"  // Oddments: Load (3, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[0], [x0]\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[0], [x9]\n"
     "67:"  // Oddments: Load (3, 5): Bit 2: End
-    "ldr d0, [x3, #0x78]\n"
-    "usubl v22.8h, v22.8b, v7.8b\n"
-    "smlal v10.4s, v22.4h, v4.4h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v22.8h, v4.8h\n"
-    "add x15, x15, x10\n"
+    "ldr d0, [x23, #0x78]\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "add x19, x19, x24\n"
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v20.4s, v31.8h, v0.8h\n"
-    "smlal v18.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "tbz x4, #2, 69f\n"
-    "ld1 { v27.s }[0], [x15], #0x4\n"
-    "tbz x4, #1, 68f\n"
-    "ld1 { v27.h }[2], [x15], #0x2\n"
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[6], [x15]\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "tbz x0, #2, 69f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 68f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 71f\n"
     "68:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[4], [x15]\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 71f\n"
     "69:"  // Oddments: Load (4, 0): Bit 2: Unset
-    "tbz x4, #1, 70f\n"
-    "ld1 { v27.h }[0], [x15], #0x2\n"
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[2], [x15]\n"
+    "tbz x0, #1, 70f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 71f\n"
     "70:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[0], [x15]\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "71:"  // Oddments: Load (4, 0): Bit 2: End
-    "ldr x9, [x25, #0xc8]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v11.4s, v27.4h, v0.4h\n"
-    "smlal2 v8.4s, v27.8h, v0.8h\n"
-    "add x9, x9, x10\n"
-    "tbz x4, #2, 73f\n"
-    "ld1 { v23.s }[0], [x9], #0x4\n"
-    "tbz x4, #1, 72f\n"
-    "ld1 { v23.h }[2], [x9], #0x2\n"
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[6], [x9]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "add x28, x28, x24\n"
+    "tbz x0, #2, 73f\n"
+    "ld1 { v23.s }[0], [x28], #0x4\n"
+    "tbz x0, #1, 72f\n"
+    "ld1 { v23.h }[2], [x28], #0x2\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[6], [x28]\n"
     "b 75f\n"
     "72:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[4], [x9]\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[4], [x28]\n"
     "b 75f\n"
     "73:"  // Oddments: Load (4, 1): Bit 2: Unset
-    "tbz x4, #1, 74f\n"
-    "ld1 { v23.h }[0], [x9], #0x2\n"
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[2], [x9]\n"
+    "tbz x0, #1, 74f\n"
+    "ld1 { v23.h }[0], [x28], #0x2\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[2], [x28]\n"
     "b 75f\n"
     "74:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[0], [x9]\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[0], [x28]\n"
     "75:"  // Oddments: Load (4, 1): Bit 2: End
-    "ldr d1, [x3, #0x80]\n"
-    "usubl v23.8h, v23.8b, v7.8b\n"
-    "smlal v10.4s, v23.4h, v0.4h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v23.8h, v0.8h\n"
-    "add x27, x27, x10\n"
+    "ldr d1, [x23, #0x80]\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "add x6, x6, x24\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v20.4s, v30.8h, v1.8h\n"
-    "smlal v18.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "smlal v11.4s, v23.4h, v1.4h\n"
-    "smlal2 v8.4s, v23.8h, v1.8h\n"
-    "tbz x4, #2, 77f\n"
-    "ld1 { v31.s }[0], [x27], #0x4\n"
-    "tbz x4, #1, 76f\n"
-    "ld1 { v31.h }[2], [x27], #0x2\n"
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[6], [x27]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "tbz x0, #2, 77f\n"
+    "ld1 { v31.s }[0], [x6], #0x4\n"
+    "tbz x0, #1, 76f\n"
+    "ld1 { v31.h }[2], [x6], #0x2\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[6], [x6]\n"
     "b 79f\n"
     "76:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[4], [x27]\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[4], [x6]\n"
     "b 79f\n"
     "77:"  // Oddments: Load (4, 2): Bit 2: Unset
-    "tbz x4, #1, 78f\n"
-    "ld1 { v31.h }[0], [x27], #0x2\n"
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[2], [x27]\n"
+    "tbz x0, #1, 78f\n"
+    "ld1 { v31.h }[0], [x6], #0x2\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[2], [x6]\n"
     "b 79f\n"
     "78:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[0], [x27]\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[0], [x6]\n"
     "79:"  // Oddments: Load (4, 2): Bit 2: End
-    "ldr d2, [x3, #0x88]\n"
-    "usubl v31.8h, v31.8b, v7.8b\n"
-    "smlal v10.4s, v31.4h, v1.4h\n"
-    "ldr x28, [x25, #0xd8]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "add x28, x28, x10\n"
+    "ldr d2, [x23, #0x88]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "add x27, x27, x24\n"
     "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v20.4s, v26.8h, v2.8h\n"
-    "smlal v18.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "smlal2 v8.4s, v31.8h, v2.8h\n"
-    "tbz x4, #2, 81f\n"
-    "ld1 { v30.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 80f\n"
-    "ld1 { v30.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[6], [x28]\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "tbz x0, #2, 81f\n"
+    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "tbz x0, #1, 80f\n"
+    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[6], [x27]\n"
     "b 83f\n"
     "80:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[4], [x28]\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[4], [x27]\n"
     "b 83f\n"
     "81:"  // Oddments: Load (4, 3): Bit 2: Unset
-    "tbz x4, #1, 82f\n"
-    "ld1 { v30.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[2], [x28]\n"
+    "tbz x0, #1, 82f\n"
+    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[2], [x27]\n"
     "b 83f\n"
     "82:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[0], [x28]\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[0], [x27]\n"
     "83:"  // Oddments: Load (4, 3): Bit 2: End
-    "ldr d3, [x3, #0x90]\n"
-    "usubl v30.8h, v30.8b, v7.8b\n"
-    "smlal v10.4s, v30.4h, v2.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v30.8h, v2.8h\n"
-    "add x12, x12, x10\n"
+    "ldr d3, [x23, #0x90]\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "add x11, x11, x24\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v20.4s, v25.8h, v3.8h\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v30.4h, v3.4h\n"
-    "smlal2 v8.4s, v30.8h, v3.8h\n"
-    "tbz x4, #2, 85f\n"
-    "ld1 { v28.s }[0], [x12], #0x4\n"
-    "tbz x4, #1, 84f\n"
-    "ld1 { v28.h }[2], [x12], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[6], [x12]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "tbz x0, #2, 85f\n"
+    "ld1 { v28.s }[0], [x11], #0x4\n"
+    "tbz x0, #1, 84f\n"
+    "ld1 { v28.h }[2], [x11], #0x2\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[6], [x11]\n"
     "b 87f\n"
     "84:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[4], [x12]\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[4], [x11]\n"
     "b 87f\n"
     "85:"  // Oddments: Load (4, 4): Bit 2: Unset
-    "tbz x4, #1, 86f\n"
-    "ld1 { v28.h }[0], [x12], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[2], [x12]\n"
+    "tbz x0, #1, 86f\n"
+    "ld1 { v28.h }[0], [x11], #0x2\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[2], [x11]\n"
     "b 87f\n"
     "86:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[0], [x12]\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[0], [x11]\n"
     "87:"  // Oddments: Load (4, 4): Bit 2: End
-    "ldr d4, [x3, #0x98]\n"
-    "usubl v28.8h, v28.8b, v7.8b\n"
-    "smlal v10.4s, v28.4h, v3.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v28.8h, v3.8h\n"
-    "add x7, x7, x10\n"
+    "ldr d4, [x23, #0x98]\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "add x17, x17, x24\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v20.4s, v24.8h, v4.8h\n"
-    "smlal v18.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "smlal v11.4s, v28.4h, v4.4h\n"
-    "smlal2 v8.4s, v28.8h, v4.8h\n"
-    "tbz x4, #2, 89f\n"
-    "ld1 { v26.s }[0], [x7], #0x4\n"
-    "tbz x4, #1, 88f\n"
-    "ld1 { v26.h }[2], [x7], #0x2\n"
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[6], [x7]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "tbz x0, #2, 89f\n"
+    "ld1 { v26.s }[0], [x17], #0x4\n"
+    "tbz x0, #1, 88f\n"
+    "ld1 { v26.h }[2], [x17], #0x2\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[6], [x17]\n"
     "b 91f\n"
     "88:"  // Oddments: Load (4, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[4], [x7]\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[4], [x17]\n"
     "b 91f\n"
     "89:"  // Oddments: Load (4, 5): Bit 2: Unset
-    "tbz x4, #1, 90f\n"
-    "ld1 { v26.h }[0], [x7], #0x2\n"
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[2], [x7]\n"
+    "tbz x0, #1, 90f\n"
+    "ld1 { v26.h }[0], [x17], #0x2\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[2], [x17]\n"
     "b 91f\n"
     "90:"  // Oddments: Load (4, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[0], [x7]\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[0], [x17]\n"
     "91:"  // Oddments: Load (4, 5): Bit 2: End
-    "ldr d0, [x3, #0xa0]\n"
-    "usubl v26.8h, v26.8b, v7.8b\n"
-    "smlal v10.4s, v26.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "usubl v0.8h, v0.8b, v13.8b\n"
-    "smlal2 v9.4s, v26.8h, v4.8h\n"
-    "add x26, x26, x10\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "usubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "add x5, x5, x24\n"
     "smlal v15.4s, v27.4h, v0.4h\n"
-    "smlal2 v20.4s, v27.8h, v0.8h\n"
-    "smlal v18.4s, v23.4h, v0.4h\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "tbz x4, #2, 93f\n"
-    "ld1 { v25.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 92f\n"
-    "ld1 { v25.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[6], [x26]\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
+    "tbz x0, #2, 93f\n"
+    "ld1 { v25.s }[0], [x5], #0x4\n"
+    "tbz x0, #1, 92f\n"
+    "ld1 { v25.h }[2], [x5], #0x2\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[6], [x5]\n"
     "b 95f\n"
     "92:"  // Oddments: Load (5, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[4], [x26]\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[4], [x5]\n"
     "b 95f\n"
     "93:"  // Oddments: Load (5, 0): Bit 2: Unset
-    "tbz x4, #1, 94f\n"
-    "ld1 { v25.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[2], [x26]\n"
+    "tbz x0, #1, 94f\n"
+    "ld1 { v25.h }[0], [x5], #0x2\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[2], [x5]\n"
     "b 95f\n"
     "94:"  // Oddments: Load (5, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[0], [x26]\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[0], [x5]\n"
     "95:"  // Oddments: Load (5, 0): Bit 2: End
-    "ldr x23, [x25, #0xf8]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v8.4s, v25.8h, v0.8h\n"
-    "add x23, x23, x10\n"
-    "tbz x4, #2, 97f\n"
-    "ld1 { v24.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 96f\n"
-    "ld1 { v24.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[6], [x23]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "add x25, x25, x24\n"
+    "tbz x0, #2, 97f\n"
+    "ld1 { v24.s }[0], [x25], #0x4\n"
+    "tbz x0, #1, 96f\n"
+    "ld1 { v24.h }[2], [x25], #0x2\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[6], [x25]\n"
     "b 99f\n"
     "96:"  // Oddments: Load (5, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[4], [x23]\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[4], [x25]\n"
     "b 99f\n"
     "97:"  // Oddments: Load (5, 1): Bit 2: Unset
-    "tbz x4, #1, 98f\n"
-    "ld1 { v24.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[2], [x23]\n"
+    "tbz x0, #1, 98f\n"
+    "ld1 { v24.h }[0], [x25], #0x2\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[2], [x25]\n"
     "b 99f\n"
     "98:"  // Oddments: Load (5, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[0], [x23]\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[0], [x25]\n"
     "99:"  // Oddments: Load (5, 1): Bit 2: End
-    "ldr d1, [x3, #0xa8]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v0.4h\n"
-    "ldr x22, [x25, #0x100]\n"
-    "usubl v1.8h, v1.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v0.8h\n"
-    "add x22, x22, x10\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "usubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v5.4s, v24.8h, v0.8h\n"
+    "add x26, x26, x24\n"
     "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v20.4s, v23.8h, v1.8h\n"
-    "smlal v18.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v8.4s, v24.8h, v1.8h\n"
-    "tbz x4, #2, 101f\n"
-    "ld1 { v27.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 100f\n"
-    "ld1 { v27.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[6], [x22]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "tbz x0, #2, 101f\n"
+    "ld1 { v27.s }[0], [x26], #0x4\n"
+    "tbz x0, #1, 100f\n"
+    "ld1 { v27.h }[2], [x26], #0x2\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[6], [x26]\n"
     "b 103f\n"
     "100:"  // Oddments: Load (5, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[4], [x22]\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[4], [x26]\n"
     "b 103f\n"
     "101:"  // Oddments: Load (5, 2): Bit 2: Unset
-    "tbz x4, #1, 102f\n"
-    "ld1 { v27.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[2], [x22]\n"
+    "tbz x0, #1, 102f\n"
+    "ld1 { v27.h }[0], [x26], #0x2\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[2], [x26]\n"
     "b 103f\n"
     "102:"  // Oddments: Load (5, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[0], [x22]\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[0], [x26]\n"
     "103:"  // Oddments: Load (5, 2): Bit 2: End
-    "ldr d2, [x3, #0xb0]\n"
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v1.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "usubl v2.8h, v2.8b, v13.8b\n"
-    "smlal2 v9.4s, v27.8h, v1.8h\n"
-    "add x20, x20, x10\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "usubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v5.4s, v27.8h, v1.8h\n"
+    "add x12, x12, x24\n"
     "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v20.4s, v31.8h, v2.8h\n"
-    "smlal v18.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "smlal v11.4s, v27.4h, v2.4h\n"
-    "smlal2 v8.4s, v27.8h, v2.8h\n"
-    "tbz x4, #2, 105f\n"
-    "ld1 { v25.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 104f\n"
-    "ld1 { v25.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[6], [x20]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "tbz x0, #2, 105f\n"
+    "ld1 { v25.s }[0], [x12], #0x4\n"
+    "tbz x0, #1, 104f\n"
+    "ld1 { v25.h }[2], [x12], #0x2\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[6], [x12]\n"
     "b 107f\n"
     "104:"  // Oddments: Load (5, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[4], [x20]\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[4], [x12]\n"
     "b 107f\n"
     "105:"  // Oddments: Load (5, 3): Bit 2: Unset
-    "tbz x4, #1, 106f\n"
-    "ld1 { v25.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[2], [x20]\n"
+    "tbz x0, #1, 106f\n"
+    "ld1 { v25.h }[0], [x12], #0x2\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[2], [x12]\n"
     "b 107f\n"
     "106:"  // Oddments: Load (5, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[0], [x20]\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[0], [x12]\n"
     "107:"  // Oddments: Load (5, 3): Bit 2: End
-    "ldr d3, [x3, #0xb8]\n"
-    "usubl v25.8h, v25.8b, v7.8b\n"
-    "smlal v10.4s, v25.4h, v2.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "usubl v3.8h, v3.8b, v13.8b\n"
-    "smlal2 v9.4s, v25.8h, v2.8h\n"
-    "add x13, x13, x10\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "usubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x14, [x20, #0x110]\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x14, x14, x24\n"
     "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v20.4s, v30.8h, v3.8h\n"
-    "smlal v18.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "smlal v11.4s, v25.4h, v3.4h\n"
-    "smlal2 v8.4s, v25.8h, v3.8h\n"
-    "tbz x4, #2, 109f\n"
-    "ld1 { v24.s }[0], [x13], #0x4\n"
-    "tbz x4, #1, 108f\n"
-    "ld1 { v24.h }[2], [x13], #0x2\n"
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[6], [x13]\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "tbz x0, #2, 109f\n"
+    "ld1 { v24.s }[0], [x14], #0x4\n"
+    "tbz x0, #1, 108f\n"
+    "ld1 { v24.h }[2], [x14], #0x2\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[6], [x14]\n"
     "b 111f\n"
     "108:"  // Oddments: Load (5, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[4], [x13]\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[4], [x14]\n"
     "b 111f\n"
     "109:"  // Oddments: Load (5, 4): Bit 2: Unset
-    "tbz x4, #1, 110f\n"
-    "ld1 { v24.h }[0], [x13], #0x2\n"
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[2], [x13]\n"
+    "tbz x0, #1, 110f\n"
+    "ld1 { v24.h }[0], [x14], #0x2\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[2], [x14]\n"
     "b 111f\n"
     "110:"  // Oddments: Load (5, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[0], [x13]\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[0], [x14]\n"
     "111:"  // Oddments: Load (5, 4): Bit 2: End
-    "ldr d4, [x3, #0xc0]\n"
-    "usubl v24.8h, v24.8b, v7.8b\n"
-    "smlal v10.4s, v24.4h, v3.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "usubl v4.8h, v4.8b, v13.8b\n"
-    "smlal2 v9.4s, v24.8h, v3.8h\n"
-    "add x21, x21, x10\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "usubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "add x21, x21, x24\n"
     "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v20.4s, v28.8h, v4.8h\n"
-    "smlal v18.4s, v26.4h, v4.4h\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "smlal v11.4s, v24.4h, v4.4h\n"
-    "smlal2 v8.4s, v24.8h, v4.8h\n"
-    "tbz x4, #2, 113f\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "tbz x0, #2, 113f\n"
     "ld1 { v27.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 112f\n"
+    "tbz x0, #1, 112f\n"
     "ld1 { v27.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[6], [x21]\n"
     "b 115f\n"
     "112:"  // Oddments: Load (5, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[4], [x21]\n"
     "b 115f\n"
     "113:"  // Oddments: Load (5, 5): Bit 2: Unset
-    "tbz x4, #1, 114f\n"
+    "tbz x0, #1, 114f\n"
     "ld1 { v27.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[2], [x21]\n"
     "b 115f\n"
     "114:"  // Oddments: Load (5, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[0], [x21]\n"
     "115:"  // Oddments: Load (5, 5): Bit 2: End
-    "usubl v27.8h, v27.8b, v7.8b\n"
-    "smlal v10.4s, v27.4h, v4.4h\n"
-    "smlal2 v9.4s, v27.8h, v4.8h\n"
-    "tbz x4, #2, 117f\n"
-    "ld1 { v6.4s }, [x2], #0x10\n"
-    "ld1 { v21.4s }, [x5], #0x10\n"
-    "tbz x4, #1, 116f\n"
-    "ld1 { v17.d }[0], [x2], #0x8\n"
-    "ld1 { v14.d }[0], [x5], #0x8\n"
-    "tbz x4, #0, 119f\n"
-    "ld1 { v17.s }[2], [x2]\n"
-    "ld1 { v14.s }[2], [x5]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "smlal2 v5.4s, v27.8h, v4.8h\n"
+    "tbz x0, #2, 117f\n"
+    "ld1 { v12.4s }, [x10], #0x10\n"
+    "ld1 { v19.4s }, [x1], #0x10\n"
+    "tbz x0, #1, 116f\n"
+    "ld1 { v20.d }[0], [x10], #0x8\n"
+    "ld1 { v29.d }[0], [x1], #0x8\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v20.s }[2], [x10]\n"
+    "ld1 { v29.s }[2], [x1]\n"
     "b 119f\n"
     "116:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 119f\n"
-    "ld1 { v17.s }[0], [x2]\n"
-    "ld1 { v14.s }[0], [x5]\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v20.s }[0], [x10]\n"
+    "ld1 { v29.s }[0], [x1]\n"
     "b 119f\n"
     "117:"  // Oddments: Load requant params: Bit 2: Unset
-    "tbz x4, #1, 118f\n"
-    "ld1 { v6.d }[0], [x2], #0x8\n"
-    "ld1 { v21.d }[0], [x5], #0x8\n"
-    "tbz x4, #0, 119f\n"
-    "ld1 { v6.s }[2], [x2]\n"
-    "ld1 { v21.s }[2], [x5]\n"
+    "tbz x0, #1, 118f\n"
+    "ld1 { v12.d }[0], [x10], #0x8\n"
+    "ld1 { v19.d }[0], [x1], #0x8\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v12.s }[2], [x10]\n"
+    "ld1 { v19.s }[2], [x1]\n"
     "b 119f\n"
     "118:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 119f\n"
-    "ld1 { v6.s }[0], [x2]\n"
-    "ld1 { v21.s }[0], [x5]\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v12.s }[0], [x10]\n"
+    "ld1 { v19.s }[0], [x1]\n"
     "119:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "add x17, x17, x1\n"
-    "sqrdmulh v20.4s, v20.4s, v17.4s\n"
-    "add x16, x16, x1\n"
-    "sqrdmulh v18.4s, v18.4s, v6.4s\n"
-    "add x6, x6, x1\n"
-    "sqrdmulh v5.4s, v5.4s, v17.4s\n"
-    "add x8, x8, x1\n"
-    "sqrdmulh v11.4s, v11.4s, v6.4s\n"
-    "and v1.16b, v15.16b, v21.16b\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "and v29.16b, v20.16b, v14.16b\n"
-    "and v3.16b, v18.16b, v21.16b\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "and v2.16b, v5.16b, v14.16b\n"
-    "and v0.16b, v11.16b, v21.16b\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "sqrdmulh v8.4s, v8.4s, v17.4s\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "add x16, x16, x22\n"
+    "add x8, x8, x22\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "add x4, x4, x22\n"
+    "add x7, x7, x22\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
     "sshr v2.4s, v2.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v1.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v6.4s\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "sqrdmulh v9.4s, v9.4s, v17.4s\n"
-    "sqadd v20.4s, v20.4s, v29.4s\n"
-    "sqadd v18.4s, v18.4s, v3.4s\n"
-    "srshl v15.4s, v15.4s, v21.4s\n"
-    "sqadd v5.4s, v5.4s, v2.4s\n"
-    "srshl v20.4s, v20.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v21.4s\n"
-    "add v15.4s, v15.4s, v19.4s\n"
-    "srshl v5.4s, v5.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v19.4s\n"
-    "smin v15.4s, v15.4s, v12.4s\n"
-    "add v18.4s, v18.4s, v19.4s\n"
-    "smin v20.4s, v20.4s, v12.4s\n"
-    "smax v15.4s, v15.4s, v16.4s\n"
-    "smin v18.4s, v18.4s, v12.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "add v5.4s, v5.4s, v19.4s\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "uzp1 v15.16b, v15.16b, v20.16b\n"
-    "smin v5.4s, v5.4s, v12.4s\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
     "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "sqadd v11.4s, v11.4s, v0.4s\n"
-    "smax v5.4s, v5.4s, v16.4s\n"
-    "and v27.16b, v8.16b, v14.16b\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v5.16b\n"
-    "srshl v11.4s, v11.4s, v21.4s\n"
-    "and v30.16b, v10.16b, v21.16b\n"
-    "sshr v30.4s, v30.4s, #0x1f\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "add v11.4s, v11.4s, v19.4s\n"
-    "sqadd v8.4s, v8.4s, v27.4s\n"
-    "and v6.16b, v9.16b, v14.16b\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "smin v11.4s, v11.4s, v12.4s\n"
-    "srshl v8.4s, v8.4s, v14.4s\n"
-    "sqadd v10.4s, v10.4s, v30.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "add v8.4s, v8.4s, v19.4s\n"
-    "srshl v10.4s, v10.4s, v21.4s\n"
-    "sqadd v9.4s, v9.4s, v6.4s\n"
-    "smin v8.4s, v8.4s, v12.4s\n"
-    "add v10.4s, v10.4s, v19.4s\n"
-    "srshl v9.4s, v9.4s, v14.4s\n"
-    "smax v8.4s, v8.4s, v16.4s\n"
-    "smin v10.4s, v10.4s, v12.4s\n"
-    "uzp1 v11.16b, v11.16b, v8.16b\n"
-    "add v9.4s, v9.4s, v19.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "smin v9.4s, v9.4s, v12.4s\n"
-    "smax v9.4s, v9.4s, v16.4s\n"
-    "uzp1 v10.16b, v10.16b, v9.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
     "uzp1 v10.16b, v10.16b, v10.16b\n"
-    "tbz x4, #2, 121f\n"
-    "st1 { v15.s }[0], [x17], #0x4\n"
-    "st1 { v18.s }[0], [x16], #0x4\n"
-    "st1 { v11.s }[0], [x6], #0x4\n"
-    "st1 { v10.s }[0], [x8], #0x4\n"
-    "tbz x4, #1, 120f\n"
-    "st1 { v15.h }[2], [x17], #0x2\n"
-    "st1 { v18.h }[2], [x16], #0x2\n"
-    "st1 { v11.h }[2], [x6], #0x2\n"
-    "st1 { v10.h }[2], [x8], #0x2\n"
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[6], [x17], #0x1\n"
-    "st1 { v18.b }[6], [x16], #0x1\n"
-    "st1 { v11.b }[6], [x6], #0x1\n"
-    "st1 { v10.b }[6], [x8], #0x1\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "tbz x0, #2, 121f\n"
+    "st1 { v15.s }[0], [x16], #0x4\n"
+    "st1 { v17.s }[0], [x8], #0x4\n"
+    "st1 { v10.s }[0], [x4], #0x4\n"
+    "st1 { v6.s }[0], [x7], #0x4\n"
+    "tbz x0, #1, 120f\n"
+    "st1 { v15.h }[2], [x16], #0x2\n"
+    "st1 { v17.h }[2], [x8], #0x2\n"
+    "st1 { v10.h }[2], [x4], #0x2\n"
+    "st1 { v6.h }[2], [x7], #0x2\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[6], [x16], #0x1\n"
+    "st1 { v17.b }[6], [x8], #0x1\n"
+    "st1 { v10.b }[6], [x4], #0x1\n"
+    "st1 { v6.b }[6], [x7], #0x1\n"
     "b 123f\n"
     "120:"  // Oddments: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[4], [x17], #0x1\n"
-    "st1 { v18.b }[4], [x16], #0x1\n"
-    "st1 { v11.b }[4], [x6], #0x1\n"
-    "st1 { v10.b }[4], [x8], #0x1\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[4], [x16], #0x1\n"
+    "st1 { v17.b }[4], [x8], #0x1\n"
+    "st1 { v10.b }[4], [x4], #0x1\n"
+    "st1 { v6.b }[4], [x7], #0x1\n"
     "b 123f\n"
     "121:"  // Oddments: Bit 2: Unset
-    "tbz x4, #1, 122f\n"
-    "st1 { v15.h }[0], [x17], #0x2\n"
-    "st1 { v18.h }[0], [x16], #0x2\n"
-    "st1 { v11.h }[0], [x6], #0x2\n"
-    "st1 { v10.h }[0], [x8], #0x2\n"
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[2], [x17], #0x1\n"
-    "st1 { v18.b }[2], [x16], #0x1\n"
-    "st1 { v11.b }[2], [x6], #0x1\n"
-    "st1 { v10.b }[2], [x8], #0x1\n"
+    "tbz x0, #1, 122f\n"
+    "st1 { v15.h }[0], [x16], #0x2\n"
+    "st1 { v17.h }[0], [x8], #0x2\n"
+    "st1 { v10.h }[0], [x4], #0x2\n"
+    "st1 { v6.h }[0], [x7], #0x2\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[2], [x16], #0x1\n"
+    "st1 { v17.b }[2], [x8], #0x1\n"
+    "st1 { v10.b }[2], [x4], #0x1\n"
+    "st1 { v6.b }[2], [x7], #0x1\n"
     "b 123f\n"
     "122:"  // Oddments: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[0], [x17], #0x1\n"
-    "st1 { v18.b }[0], [x16], #0x1\n"
-    "st1 { v11.b }[0], [x6], #0x1\n"
-    "st1 { v10.b }[0], [x8], #0x1\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[0], [x16], #0x1\n"
+    "st1 { v17.b }[0], [x8], #0x1\n"
+    "st1 { v10.b }[0], [x4], #0x1\n"
+    "st1 { v6.b }[0], [x7], #0x1\n"
     "123:"  // Oddments: Bit 2: End
-
     "124:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_generic_output9_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_generic_output9_mla_depthfirst.hpp
index f5459c2..b859978 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_generic_output9_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_nhwc_generic_output9_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,28 +28,23 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_u8q_nhwc_generic_output9_mla_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const void *, const arm_gemm::Requantize32&, const unsigned int, const unsigned int);
 
-struct a64_u8q_nhwc_generic_output9_mla_depthfirst
+class a64_u8q_nhwc_generic_output9_mla_depthfirst : public GenericDepthfirstKernelStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy wrapper around the generic 9-output u8q NHWC kernel.
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  KernelType kernel = a64_u8q_nhwc_generic_output9_mla_depthfirst_impl;  // Private (class default access): kernel entry point handed out by get_kernel().
 
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const void *, const arm_gemm::Requantize32&, const unsigned int, const unsigned int);
+  public:
+  a64_u8q_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) : GenericDepthfirstKernelStrategy<uint8_t, uint8_t, uint8_t, int32_t>(9, arm_gemm::VLType::None) {}  // CPUInfo is ignored; the base receives 9 (presumably the output point count, cf. "output9") and VLType::None.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int n_output_points = 9;
-
-  kern_type kernel = a64_u8q_nhwc_generic_output9_mla_depthfirst_impl;
-
-  a64_u8q_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) {}
+  KernelType get_kernel() const override { return kernel; }  // Returns the stored kernel function pointer.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
index e8ac603..134f657 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,39 +28,33 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst
+struct a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst : DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy wrapper for the 3x3, stride-2 u8q depth-multiplier kernel.
 {
-  typedef uint32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
-
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
+  using Parent = DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 4;
+  a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *)  // The CPUInfo argument is not used.
+  : Parent(2, 4, kernel_rows, kernel_cols, stride_rows, stride_cols)  // Leading 2, 4 are presumably output rows x cols (cf. "output2x4") - confirm against Parent.
+  {
+  }
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 9;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::None; }  // Always reports VLType::None.
 
-  kern_type kernel = a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;
-
-  a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;  // Kernel entry point returned by get_kernel().
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel function pointer.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
index c5e0417..b575a5d 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,39 +28,33 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst
+struct a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst : DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy wrapper for the 5x5, stride-1 u8q depth-multiplier kernel.
 {
-  typedef uint32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
-
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
+  using Parent = DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 4;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *)  // The CPUInfo argument is not used.
+  : Parent(4, 2, kernel_rows, kernel_cols, stride_rows, stride_cols)  // Leading 4, 2 are presumably output rows x cols (cf. "output4x2") - confirm against Parent.
+  {
+  }
 
-  constexpr static unsigned int input_rows = 8;
-  constexpr static unsigned int input_cols = 6;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::None; }  // Always reports VLType::None.
 
-  kern_type kernel = a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;
-
-  a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;  // Kernel entry point returned by get_kernel().
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel function pointer.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
index 6b52017..13f903b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,31 +28,25 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const uint8_t *, const int32_t *, const unsigned int, const unsigned int, const int32_t *, const int32_t *, const int32_t *, const arm_gemm::Requantize32&);
 
-struct a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst
+struct a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst : GenericDepthfirstMultiplierKernelStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy wrapper for the generic u8q depth-multiplier kernel.
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
-
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const uint8_t *, const int32_t *, const unsigned int, const unsigned int, const int32_t *, const int32_t *, const int32_t *, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int output_rows(void) { return 2; };
-  constexpr static unsigned int output_cols(void) { return 8; };
-
-  constexpr static unsigned int output_col_regs(void) { return 2; };
-
-  kern_type kernel = a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
-
-  a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *) {}
+  using Parent = GenericDepthfirstMultiplierKernelStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
+  a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *)  // The CPUInfo argument is not used.
+  : Parent(2, 8, arm_gemm::VLType::None)  // 2, 8 are presumably output rows x cols (cf. "output2x8") - confirm against Parent.
+  {
+  }
+  Parent::KernelType kernel = a64_u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;  // Kernel entry point returned by get_kernel().
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel function pointer.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
new file mode 100644
index 0000000..2d2b452
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+
+#include "src/core/NEON/kernels/arm_conv/depthwise/interleaves/list.hpp"
+
+#include <cstdint>
+
+#pragma once
+
+#if defined(__aarch64__)
+
+namespace arm_conv {
+namespace depthwise {
+
+void a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);  // Args in order: n_channels, inptrs, weights, bias, qp, requant_muls, requant_shifts, outptrs (defined in a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp).
+
+class a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy wrapper for the 3x3, stride-1 u8qa NHWC kernel.
+{
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
+
+  public:
+  constexpr static unsigned int kernel_rows = 3;
+  constexpr static unsigned int kernel_cols = 3;
+
+  constexpr static unsigned int stride_rows = 1;
+  constexpr static unsigned int stride_cols = 1;
+
+  a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, kernel_rows, kernel_cols, stride_rows, stride_cols) {}  // CPUInfo is unused; named constants keep the base arguments in step with the declarations above (same values as before: 3, 3, 1, 1).
+
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }  // Always reports VLType::None.
+
+  Parent::KernelType kernel = a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;  // Kernel entry point returned by get_kernel().
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel function pointer.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably two 4x32-bit accumulator vectors per output, as the asm pairs q registers per output - confirm against Parent.
+};
+
+}  // namespace depthwise
+}  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
new file mode 100644
index 0000000..2410d38
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
@@ -0,0 +1,1164 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_gemm.hpp"
+
+#include <cstddef>
+#include <cstdint>
+
+#if defined(__aarch64__)
+
+namespace arm_conv {
+namespace depthwise {
+
+void a64_u8qa_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(
+  const unsigned int n_channels,
+  const uint8_t *const *const inptrs,
+  const uint8_t *const weights,
+  const int32_t *const bias,
+  const arm_gemm::Requantize32 &qp,
+  const int32_t *const requant_muls,
+  const int32_t *const requant_shifts,
+  uint8_t *const *const outptrs
+)
+{
+  struct Params  // Argument block that the inline asm reads through offsetof(Params, ...) displacements.
+  {
+    long unsigned int n_channels;  // Loaded by the asm with a 64-bit ldr.
+    const void *weights;
+    const int32_t *bias;  // Mutable pointer: the asm stores the advanced bias pointer back into this field.
+    const arm_gemm::Requantize32 *requant;
+    const int32_t *const requant_muls;
+    const int32_t *const requant_shifts;
+    uint8_t *const *const outptrs;
+    const uint8_t *inptrs[16];  // Input pointers, permuted by the constructor below.
+
+    Params(
+      long unsigned int n_channels,
+      const uint8_t *const *inptrs_raw,
+      const void *const weights,
+      const int32_t *const bias,
+      const arm_gemm::Requantize32 &qp,
+      const int32_t *const requant_muls,
+      const int32_t *const requant_shifts,
+      uint8_t *const *outptrs
+    ) : n_channels(n_channels), weights(weights), bias(bias),
+        requant(&qp), requant_muls(requant_muls),
+        requant_shifts(requant_shifts), outptrs(outptrs)
+    {
+      inptrs[0] = inptrs_raw[5];  // Permute the 16 raw pointers; the asm loads slots 0-4 (offsets 0x0-0x20) first and slots 5-15 (0x28-0x78) afterwards, so this presumably matches its consumption order.
+      inptrs[1] = inptrs_raw[0];
+      inptrs[2] = inptrs_raw[3];
+      inptrs[3] = inptrs_raw[6];
+      inptrs[4] = inptrs_raw[9];
+      inptrs[5] = inptrs_raw[12];
+      inptrs[6] = inptrs_raw[15];
+      inptrs[7] = inptrs_raw[1];
+      inptrs[8] = inptrs_raw[2];
+      inptrs[9] = inptrs_raw[10];
+      inptrs[10] = inptrs_raw[4];
+      inptrs[11] = inptrs_raw[7];
+      inptrs[12] = inptrs_raw[8];
+      inptrs[13] = inptrs_raw[11];
+      inptrs[14] = inptrs_raw[13];
+      inptrs[15] = inptrs_raw[14];
+
+    }
+  };
+
+  Params params(n_channels, inptrs, weights, bias, qp,  // Deliberately non-const: the asm below writes params.bias, and modifying an object defined const is undefined behaviour.
+                requant_muls, requant_shifts, outptrs);
+
+  __asm__ __volatile__(
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "add x22, x19, %[offsetof_Requantize32_c_offset]\n"
+    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "ld1r { v15.16b }, [x23]\n"
+    "ld1r { v13.8h }, [x22]\n"
+    "lsr x16, x8, #0x3\n"
+    "mov x15, #0x0\n"
+    "ld1r { v11.8h }, [x20]\n"
+    "ld1r { v25.8h }, [x19]\n"
+    "mov x14, #0x0\n"
+    "add x13, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x12, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x10, x9, [x21, #0x0]\n"
+    "ldp x28, x27, [x21, #0x10]\n"
+    "cbz x16, 3f\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q12, [x19, #0x0]\n"
+    "subs x16, x16, #0x1\n"
+    "mov v14.16b, v12.16b\n"
+    "ldr q17, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v9.16b, v17.16b\n"
+    "mov v16.16b, v12.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v10.16b, v17.16b\n"
+    "mov v18.16b, v12.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v26.16b, v17.16b\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "ldp x23, x22, [x13, #0x0]\n"
+    "ldp x21, x20, [x13, #0x10]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "ldr x19, [x13, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "usubl v5.8h, v5.8b, v15.8b\n"
+    "usubl v6.8h, v6.8b, v15.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "usubl v7.8h, v7.8b, v15.8b\n"
+    "usubl v8.8h, v8.8b, v15.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "beq 2f\n"
+    "1:"  // Loop
+    "smlal v12.4s, v31.4h, v4.4h\n"
+    "smlal2 v17.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x13, #0x28]\n"
+    "ldr x26, [x13, #0x38]\n"
+    "smlal v14.4s, v31.4h, v3.4h\n"
+    "smlal2 v9.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x13, #0x30]\n"
+    "ldr x25, [x13, #0x40]\n"
+    "smlal v12.4s, v30.4h, v0.4h\n"
+    "smlal2 v17.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x13, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v14.4s, v29.4h, v2.4h\n"
+    "smlal2 v9.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v16.4s, v31.4h, v1.4h\n"
+    "smlal2 v10.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x13, #0x50]\n"
+    "ldr x23, [x13, #0x58]\n"
+    "smlal v18.4s, v31.4h, v0.4h\n"
+    "smlal2 v26.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v12.4s, v28.4h, v5.4h\n"
+    "smlal2 v17.4s, v28.8h, v5.8h\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ldr x22, [x13, #0x60]\n"
+    "smlal v14.4s, v28.4h, v4.4h\n"
+    "smlal2 v9.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x13, #0x68]\n"
+    "ldr x20, [x13, #0x70]\n"
+    "smlal v16.4s, v28.4h, v2.4h\n"
+    "smlal2 v10.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x13, #0x78]\n"
+    "ldr q21, [x12, #0x0]\n"
+    "smlal v18.4s, v28.4h, v1.4h\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v12.4s, v27.4h, v7.4h\n"
+    "smlal2 v17.4s, v27.8h, v7.8h\n"
+    "ldr q24, [x11, #0x0]\n"
+    "ldr q19, [x12, #0x10]\n"
+    "smlal v14.4s, v27.4h, v6.4h\n"
+    "smlal2 v9.4s, v27.8h, v6.8h\n"
+    "ldr q23, [x11, #0x10]\n"
+    "add x17, x17, #0x48\n"
+    "smlal v16.4s, v31.4h, v6.4h\n"
+    "smlal2 v10.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v18.4s, v27.4h, v3.4h\n"
+    "smlal2 v26.4s, v27.8h, v3.8h\n"
+    "subs x16, x16, #0x1\n"
+    "add x12, x12, #0x20\n"
+    "smlal v12.4s, v28.4h, v1.4h\n"
+    "smlal2 v17.4s, v28.8h, v1.8h\n"
+    "add x11, x11, #0x20\n"
+    "smlal v14.4s, v28.4h, v0.4h\n"
+    "smlal2 v9.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v16.4s, v27.4h, v4.4h\n"
+    "smlal v18.4s, v29.4h, v8.4h\n"
+    "smlal2 v10.4s, v27.8h, v4.8h\n"
+    "smlal2 v26.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v12.4s, v31.4h, v2.4h\n"
+    "smlal2 v17.4s, v31.8h, v2.8h\n"
+    "smlal v14.4s, v31.4h, v1.4h\n"
+    "smlal2 v9.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v16.4s, v30.4h, v5.4h\n"
+    "smlal v18.4s, v30.4h, v4.4h\n"
+    "smlal v12.4s, v30.4h, v8.4h\n"
+    "smlal2 v17.4s, v30.8h, v8.8h\n"
+    "smlal v14.4s, v30.4h, v7.4h\n"
+    "smlal2 v9.4s, v30.8h, v7.8h\n"
+    "smlal2 v10.4s, v30.8h, v5.8h\n"
+    "smlal2 v26.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v16.4s, v29.4h, v0.4h\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal v12.4s, v29.4h, v3.4h\n"
+    "smlal2 v17.4s, v29.8h, v3.8h\n"
+    "smlal2 v10.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v26.4s, v28.8h, v2.8h\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v16.4s, v31.4h, v3.4h\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v14.4s, v28.4h, v5.4h\n"
+    "smlal2 v9.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal2 v10.4s, v31.8h, v3.8h\n"
+    "smlal2 v26.4s, v30.8h, v5.8h\n"
+    "add x15, x15, #0x8\n"
+    "smlal v16.4s, v29.4h, v7.4h\n"
+    "smlal v18.4s, v29.4h, v6.4h\n"
+    "smlal2 v10.4s, v29.8h, v7.8h\n"
+    "smlal2 v26.4s, v29.8h, v6.8h\n"
+    "smlal v12.4s, v31.4h, v6.4h\n"
+    "smlal v14.4s, v30.4h, v8.4h\n"
+    "sqdmulh v12.4s, v12.4s, v21.4s\n"
+    "smlal v16.4s, v28.4h, v8.4h\n"
+    "smlal v18.4s, v28.4h, v7.4h\n"
+    "sqdmulh v14.4s, v14.4s, v21.4s\n"
+    "smlal2 v17.4s, v31.8h, v6.8h\n"
+    "smlal2 v9.4s, v30.8h, v8.8h\n"
+    "sqdmulh v16.4s, v16.4s, v21.4s\n"
+    "smlal2 v10.4s, v28.8h, v8.8h\n"
+    "smlal2 v26.4s, v28.8h, v7.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "and v29.16b, v12.16b, v24.16b\n"
+    "sqdmulh v17.4s, v17.4s, v19.4s\n"
+    "and v22.16b, v14.16b, v24.16b\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "and v21.16b, v16.16b, v24.16b\n"
+    "sqdmulh v10.4s, v10.4s, v19.4s\n"
+    "and v20.16b, v18.16b, v24.16b\n"
+    "sqdmulh v26.4s, v26.4s, v19.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v19.16b, v17.16b, v23.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v30.16b, v9.16b, v23.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v10.16b, v23.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v28.16b, v26.16b, v23.16b\n"
+    "sqadd v12.4s, v12.4s, v29.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v14.4s, v14.4s, v22.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "sqadd v16.4s, v16.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v20.4s\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "srshl v12.4s, v12.4s, v24.4s\n"
+    "sqadd v17.4s, v17.4s, v19.4s\n"
+    "srshl v14.4s, v14.4s, v24.4s\n"
+    "sqadd v9.4s, v9.4s, v30.4s\n"
+    "srshl v16.4s, v16.4s, v24.4s\n"
+    "sqadd v10.4s, v10.4s, v3.4s\n"
+    "srshl v18.4s, v18.4s, v24.4s\n"
+    "sqadd v26.4s, v26.4s, v28.4s\n"
+    "srshl v17.4s, v17.4s, v23.4s\n"
+    "sqxtn v12.4h, v12.4s\n"
+    "srshl v9.4s, v9.4s, v23.4s\n"
+    "sqxtn v14.4h, v14.4s\n"
+    "srshl v10.4s, v10.4s, v23.4s\n"
+    "sqxtn v16.4h, v16.4s\n"
+    "srshl v26.4s, v26.4s, v23.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "sqxtn2 v12.8h, v17.4s\n"
+    "sqxtn2 v14.8h, v9.4s\n"
+    "sqxtn2 v16.8h, v10.4s\n"
+    "sqxtn2 v18.8h, v26.4s\n"
+    "sqadd v12.8h, v12.8h, v13.8h\n"
+    "sqadd v14.8h, v14.8h, v13.8h\n"
+    "sqadd v16.8h, v16.8h, v13.8h\n"
+    "sqadd v18.8h, v18.8h, v13.8h\n"
+    "smax v12.8h, v12.8h, v11.8h\n"
+    "smax v14.8h, v14.8h, v11.8h\n"
+    "smax v16.8h, v16.8h, v11.8h\n"
+    "smax v18.8h, v18.8h, v11.8h\n"
+    "smin v12.8h, v12.8h, v25.8h\n"
+    "smin v14.8h, v14.8h, v25.8h\n"
+    "smin v16.8h, v16.8h, v25.8h\n"
+    "smin v18.8h, v18.8h, v25.8h\n"
+    "uzp1 v12.16b, v12.16b, v12.16b\n"
+    "uzp1 v14.16b, v14.16b, v14.16b\n"
+    "str d12, [x10, x14]\n"
+    "uzp1 v16.16b, v16.16b, v16.16b\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "str d14, [x9, x14]\n"
+    "str d16, [x28, x14]\n"
+    "str d18, [x27, x14]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q12, [x19, #0x0]\n"
+    "add x14, x14, #0x8\n"
+    "ldr q17, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v14.16b, v12.16b\n"
+    "mov v9.16b, v17.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v16.16b, v12.16b\n"
+    "mov v10.16b, v17.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v18.16b, v12.16b\n"
+    "mov v26.16b, v17.16b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldp x23, x22, [x13, #0x0]\n"
+    "ldp x21, x20, [x13, #0x10]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "ldr x19, [x13, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "usubl v5.8h, v5.8b, v15.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "usubl v6.8h, v6.8b, v15.8b\n"
+    "usubl v7.8h, v7.8b, v15.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v8.8h, v8.8b, v15.8b\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "bgt 1b\n"
+    "2:"  // Tail
+    "smlal v12.4s, v31.4h, v4.4h\n"
+    "smlal2 v17.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x13, #0x28]\n"
+    "ldr x26, [x13, #0x38]\n"
+    "smlal v14.4s, v31.4h, v3.4h\n"
+    "smlal2 v9.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x13, #0x30]\n"
+    "ldr x25, [x13, #0x40]\n"
+    "smlal v12.4s, v30.4h, v0.4h\n"
+    "smlal2 v17.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x13, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v14.4s, v29.4h, v2.4h\n"
+    "smlal2 v9.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v16.4s, v31.4h, v1.4h\n"
+    "smlal2 v10.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x13, #0x50]\n"
+    "ldr x23, [x13, #0x58]\n"
+    "smlal v18.4s, v31.4h, v0.4h\n"
+    "smlal2 v26.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v12.4s, v28.4h, v5.4h\n"
+    "smlal2 v17.4s, v28.8h, v5.8h\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ldr x22, [x13, #0x60]\n"
+    "smlal v14.4s, v28.4h, v4.4h\n"
+    "smlal2 v9.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x13, #0x68]\n"
+    "ldr x20, [x13, #0x70]\n"
+    "smlal v16.4s, v28.4h, v2.4h\n"
+    "smlal2 v10.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x13, #0x78]\n"
+    "ldr q21, [x12, #0x0]\n"
+    "smlal v18.4s, v28.4h, v1.4h\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v12.4s, v27.4h, v7.4h\n"
+    "smlal2 v17.4s, v27.8h, v7.8h\n"
+    "ldr q24, [x11, #0x0]\n"
+    "ldr q19, [x12, #0x10]\n"
+    "smlal v14.4s, v27.4h, v6.4h\n"
+    "smlal2 v9.4s, v27.8h, v6.8h\n"
+    "ldr q23, [x11, #0x10]\n"
+    "tst x8, #0x7\n"
+    "smlal v16.4s, v31.4h, v6.4h\n"
+    "smlal2 v10.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v18.4s, v27.4h, v3.4h\n"
+    "smlal2 v26.4s, v27.8h, v3.8h\n"
+    "add x12, x12, #0x20\n"
+    "add x11, x11, #0x20\n"
+    "smlal v12.4s, v28.4h, v1.4h\n"
+    "smlal2 v17.4s, v28.8h, v1.8h\n"
+    "smlal v14.4s, v28.4h, v0.4h\n"
+    "smlal2 v9.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v16.4s, v27.4h, v4.4h\n"
+    "smlal v18.4s, v29.4h, v8.4h\n"
+    "smlal2 v10.4s, v27.8h, v4.8h\n"
+    "smlal2 v26.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v12.4s, v31.4h, v2.4h\n"
+    "smlal2 v17.4s, v31.8h, v2.8h\n"
+    "smlal v14.4s, v31.4h, v1.4h\n"
+    "smlal2 v9.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v16.4s, v30.4h, v5.4h\n"
+    "smlal v18.4s, v30.4h, v4.4h\n"
+    "smlal v12.4s, v30.4h, v8.4h\n"
+    "smlal2 v17.4s, v30.8h, v8.8h\n"
+    "smlal v14.4s, v30.4h, v7.4h\n"
+    "smlal2 v9.4s, v30.8h, v7.8h\n"
+    "smlal2 v10.4s, v30.8h, v5.8h\n"
+    "smlal2 v26.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v16.4s, v29.4h, v0.4h\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal v12.4s, v29.4h, v3.4h\n"
+    "smlal2 v17.4s, v29.8h, v3.8h\n"
+    "smlal2 v10.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v26.4s, v28.8h, v2.8h\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v16.4s, v31.4h, v3.4h\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v14.4s, v28.4h, v5.4h\n"
+    "smlal2 v9.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal2 v10.4s, v31.8h, v3.8h\n"
+    "smlal2 v26.4s, v30.8h, v5.8h\n"
+    "add x15, x15, #0x8\n"
+    "smlal v16.4s, v29.4h, v7.4h\n"
+    "smlal v18.4s, v29.4h, v6.4h\n"
+    "smlal2 v10.4s, v29.8h, v7.8h\n"
+    "smlal2 v26.4s, v29.8h, v6.8h\n"
+    "smlal v12.4s, v31.4h, v6.4h\n"
+    "smlal v14.4s, v30.4h, v8.4h\n"
+    "sqdmulh v12.4s, v12.4s, v21.4s\n"
+    "smlal v16.4s, v28.4h, v8.4h\n"
+    "smlal v18.4s, v28.4h, v7.4h\n"
+    "sqdmulh v14.4s, v14.4s, v21.4s\n"
+    "smlal2 v17.4s, v31.8h, v6.8h\n"
+    "smlal2 v9.4s, v30.8h, v8.8h\n"
+    "sqdmulh v16.4s, v16.4s, v21.4s\n"
+    "smlal2 v10.4s, v28.8h, v8.8h\n"
+    "smlal2 v26.4s, v28.8h, v7.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "and v29.16b, v12.16b, v24.16b\n"
+    "sqdmulh v17.4s, v17.4s, v19.4s\n"
+    "and v22.16b, v14.16b, v24.16b\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "and v21.16b, v16.16b, v24.16b\n"
+    "sqdmulh v10.4s, v10.4s, v19.4s\n"
+    "and v20.16b, v18.16b, v24.16b\n"
+    "sqdmulh v26.4s, v26.4s, v19.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v19.16b, v17.16b, v23.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v30.16b, v9.16b, v23.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v10.16b, v23.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v28.16b, v26.16b, v23.16b\n"
+    "sqadd v12.4s, v12.4s, v29.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v14.4s, v14.4s, v22.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "sqadd v16.4s, v16.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v20.4s\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "srshl v12.4s, v12.4s, v24.4s\n"
+    "sqadd v17.4s, v17.4s, v19.4s\n"
+    "srshl v14.4s, v14.4s, v24.4s\n"
+    "sqadd v9.4s, v9.4s, v30.4s\n"
+    "srshl v16.4s, v16.4s, v24.4s\n"
+    "sqadd v10.4s, v10.4s, v3.4s\n"
+    "srshl v18.4s, v18.4s, v24.4s\n"
+    "sqadd v26.4s, v26.4s, v28.4s\n"
+    "srshl v17.4s, v17.4s, v23.4s\n"
+    "sqxtn v12.4h, v12.4s\n"
+    "srshl v9.4s, v9.4s, v23.4s\n"
+    "sqxtn v14.4h, v14.4s\n"
+    "srshl v10.4s, v10.4s, v23.4s\n"
+    "sqxtn v16.4h, v16.4s\n"
+    "srshl v26.4s, v26.4s, v23.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "sqxtn2 v12.8h, v17.4s\n"
+    "sqxtn2 v14.8h, v9.4s\n"
+    "sqxtn2 v16.8h, v10.4s\n"
+    "sqxtn2 v18.8h, v26.4s\n"
+    "sqadd v12.8h, v12.8h, v13.8h\n"
+    "sqadd v14.8h, v14.8h, v13.8h\n"
+    "sqadd v16.8h, v16.8h, v13.8h\n"
+    "sqadd v18.8h, v18.8h, v13.8h\n"
+    "smax v12.8h, v12.8h, v11.8h\n"
+    "smax v14.8h, v14.8h, v11.8h\n"
+    "smax v16.8h, v16.8h, v11.8h\n"
+    "smax v18.8h, v18.8h, v11.8h\n"
+    "smin v12.8h, v12.8h, v25.8h\n"
+    "smin v14.8h, v14.8h, v25.8h\n"
+    "smin v16.8h, v16.8h, v25.8h\n"
+    "smin v18.8h, v18.8h, v25.8h\n"
+    "uzp1 v12.16b, v12.16b, v12.16b\n"
+    "uzp1 v14.16b, v14.16b, v14.16b\n"
+    "str d12, [x10, x14]\n"
+    "uzp1 v16.16b, v16.16b, v16.16b\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "str d14, [x9, x14]\n"
+    "str d16, [x28, x14]\n"
+    "str d18, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
+    "beq 64f\n"
+    "add x17, x17, #0x48\n"
+    "3:"  // Oddments
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "tbz x8, #2, 5f\n"
+    "ld1 { v12.4s }, [x19], #0x10\n"
+    "tbz x8, #1, 4f\n"
+    "ld1 { v17.d }[0], [x19], #0x8\n"
+    "tbz x8, #0, 7f\n"
+    "ld1 { v17.s }[2], [x19]\n"
+    "b 7f\n"
+    "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 7f\n"
+    "ld1 { v17.s }[0], [x19]\n"
+    "b 7f\n"
+    "5:"  // Oddments: Load bias: Bit 2: Unset
+    "tbz x8, #1, 6f\n"
+    "ld1 { v12.d }[0], [x19], #0x8\n"
+    "tbz x8, #0, 7f\n"
+    "ld1 { v12.s }[2], [x19]\n"
+    "b 7f\n"
+    "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 7f\n"
+    "ld1 { v12.s }[0], [x19]\n"
+    "7:"  // Oddments: Load bias: Bit 2: End
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v14.16b, v12.16b\n"
+    "mov v9.16b, v17.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v16.16b, v12.16b\n"
+    "mov v10.16b, v17.16b\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v18.16b, v12.16b\n"
+    "mov v26.16b, v17.16b\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x23, x22, [x13, #0x0]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "ldp x21, x20, [x13, #0x10]\n"
+    "ldr x19, [x13, #0x20]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "usubl v5.8h, v5.8b, v15.8b\n"
+    "usubl v6.8h, v6.8b, v15.8b\n"
+    "usubl v7.8h, v7.8b, v15.8b\n"
+    "usubl v8.8h, v8.8b, v15.8b\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 9f\n"
+    "ld1 { v31.s }[0], [x23], #0x4\n"
+    "ld1 { v30.s }[0], [x22], #0x4\n"
+    "ld1 { v29.s }[0], [x21], #0x4\n"
+    "ld1 { v28.s }[0], [x20], #0x4\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 8f\n"
+    "ld1 { v31.h }[2], [x23], #0x2\n"
+    "ld1 { v30.h }[2], [x22], #0x2\n"
+    "ld1 { v29.h }[2], [x21], #0x2\n"
+    "ld1 { v28.h }[2], [x20], #0x2\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[6], [x23]\n"
+    "ld1 { v30.b }[6], [x22]\n"
+    "ld1 { v29.b }[6], [x21]\n"
+    "ld1 { v28.b }[6], [x20]\n"
+    "ld1 { v27.b }[6], [x19]\n"
+    "b 11f\n"
+    "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[4], [x23]\n"
+    "ld1 { v30.b }[4], [x22]\n"
+    "ld1 { v29.b }[4], [x21]\n"
+    "ld1 { v28.b }[4], [x20]\n"
+    "ld1 { v27.b }[4], [x19]\n"
+    "b 11f\n"
+    "9:"  // Oddments: Initial loads: Bit 2: Unset
+    "tbz x8, #1, 10f\n"
+    "ld1 { v31.h }[0], [x23], #0x2\n"
+    "ld1 { v30.h }[0], [x22], #0x2\n"
+    "ld1 { v29.h }[0], [x21], #0x2\n"
+    "ld1 { v28.h }[0], [x20], #0x2\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[2], [x23]\n"
+    "ld1 { v30.b }[2], [x22]\n"
+    "ld1 { v29.b }[2], [x21]\n"
+    "ld1 { v28.b }[2], [x20]\n"
+    "ld1 { v27.b }[2], [x19]\n"
+    "b 11f\n"
+    "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[0], [x23]\n"
+    "ld1 { v30.b }[0], [x22]\n"
+    "ld1 { v29.b }[0], [x21]\n"
+    "ld1 { v28.b }[0], [x20]\n"
+    "ld1 { v27.b }[0], [x19]\n"
+    "11:"  // Oddments: Initial loads: Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v12.4s, v31.4h, v4.4h\n"
+    "smlal2 v17.4s, v31.8h, v4.8h\n"
+    "ldr x21, [x13, #0x28]\n"
+    "smlal v14.4s, v31.4h, v3.4h\n"
+    "smlal2 v9.4s, v31.8h, v3.8h\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "add x21, x21, x15\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v16.4s, v31.4h, v1.4h\n"
+    "smlal2 v10.4s, v31.8h, v1.8h\n"
+    "smlal v18.4s, v31.4h, v0.4h\n"
+    "smlal2 v26.4s, v31.8h, v0.8h\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v12.4s, v30.4h, v0.4h\n"
+    "smlal2 v17.4s, v30.8h, v0.8h\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v14.4s, v29.4h, v2.4h\n"
+    "smlal2 v9.4s, v29.8h, v2.8h\n"
+    "smlal v12.4s, v28.4h, v5.4h\n"
+    "smlal2 v17.4s, v28.8h, v5.8h\n"
+    "smlal v14.4s, v28.4h, v4.4h\n"
+    "smlal2 v9.4s, v28.8h, v4.8h\n"
+    "smlal v16.4s, v28.4h, v2.4h\n"
+    "smlal2 v10.4s, v28.8h, v2.8h\n"
+    "smlal v18.4s, v28.4h, v1.4h\n"
+    "smlal2 v26.4s, v28.8h, v1.8h\n"
+    "tbz x8, #2, 13f\n"
+    "ld1 { v31.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 12f\n"
+    "ld1 { v31.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v31.b }[6], [x21]\n"
+    "b 15f\n"
+    "12:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 15f\n"
+    "ld1 { v31.b }[4], [x21]\n"
+    "b 15f\n"
+    "13:"  // Oddments: Load (3, 0): Bit 2: Unset
+    "tbz x8, #1, 14f\n"
+    "ld1 { v31.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v31.b }[2], [x21]\n"
+    "b 15f\n"
+    "14:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 15f\n"
+    "ld1 { v31.b }[0], [x21]\n"
+    "15:"  // Oddments: Load (3, 0): Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v16.4s, v31.4h, v6.4h\n"
+    "smlal2 v10.4s, v31.8h, v6.8h\n"
+    "ldr x20, [x13, #0x30]\n"
+    "smlal v12.4s, v27.4h, v7.4h\n"
+    "smlal2 v17.4s, v27.8h, v7.8h\n"
+    "add x20, x20, x15\n"
+    "smlal v14.4s, v27.4h, v6.4h\n"
+    "smlal2 v9.4s, v27.8h, v6.8h\n"
+    "smlal v16.4s, v27.4h, v4.4h\n"
+    "smlal2 v10.4s, v27.8h, v4.8h\n"
+    "smlal v18.4s, v27.4h, v3.4h\n"
+    "smlal2 v26.4s, v27.8h, v3.8h\n"
+    "tbz x8, #2, 17f\n"
+    "ld1 { v29.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 16f\n"
+    "ld1 { v29.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v29.b }[6], [x20]\n"
+    "b 19f\n"
+    "16:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 19f\n"
+    "ld1 { v29.b }[4], [x20]\n"
+    "b 19f\n"
+    "17:"  // Oddments: Load (3, 3): Bit 2: Unset
+    "tbz x8, #1, 18f\n"
+    "ld1 { v29.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v29.b }[2], [x20]\n"
+    "b 19f\n"
+    "18:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 19f\n"
+    "ld1 { v29.b }[0], [x20]\n"
+    "19:"  // Oddments: Load (3, 3): Bit 2: End
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr x26, [x13, #0x38]\n"
+    "smlal v18.4s, v29.4h, v8.4h\n"
+    "smlal2 v26.4s, v29.8h, v8.8h\n"
+    "add x26, x26, x15\n"
+    "tbz x8, #2, 21f\n"
+    "ld1 { v28.s }[0], [x26], #0x4\n"
+    "tbz x8, #1, 20f\n"
+    "ld1 { v28.h }[2], [x26], #0x2\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v28.b }[6], [x26]\n"
+    "b 23f\n"
+    "20:"  // Oddments: Load (0, 1): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 23f\n"
+    "ld1 { v28.b }[4], [x26]\n"
+    "b 23f\n"
+    "21:"  // Oddments: Load (0, 1): Bit 2: Unset
+    "tbz x8, #1, 22f\n"
+    "ld1 { v28.h }[0], [x26], #0x2\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v28.b }[2], [x26]\n"
+    "b 23f\n"
+    "22:"  // Oddments: Load (0, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 23f\n"
+    "ld1 { v28.b }[0], [x26]\n"
+    "23:"  // Oddments: Load (0, 1): Bit 2: End
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr x25, [x13, #0x40]\n"
+    "smlal v12.4s, v28.4h, v1.4h\n"
+    "smlal2 v17.4s, v28.8h, v1.8h\n"
+    "smlal v14.4s, v28.4h, v0.4h\n"
+    "smlal2 v9.4s, v28.8h, v0.8h\n"
+    "add x25, x25, x15\n"
+    "tbz x8, #2, 25f\n"
+    "ld1 { v31.s }[0], [x25], #0x4\n"
+    "tbz x8, #1, 24f\n"
+    "ld1 { v31.h }[2], [x25], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v31.b }[6], [x25]\n"
+    "b 27f\n"
+    "24:"  // Oddments: Load (0, 2): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 27f\n"
+    "ld1 { v31.b }[4], [x25]\n"
+    "b 27f\n"
+    "25:"  // Oddments: Load (0, 2): Bit 2: Unset
+    "tbz x8, #1, 26f\n"
+    "ld1 { v31.h }[0], [x25], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v31.b }[2], [x25]\n"
+    "b 27f\n"
+    "26:"  // Oddments: Load (0, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 27f\n"
+    "ld1 { v31.b }[0], [x25]\n"
+    "27:"  // Oddments: Load (0, 2): Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ldr x19, [x13, #0x48]\n"
+    "smlal v12.4s, v31.4h, v2.4h\n"
+    "smlal2 v17.4s, v31.8h, v2.8h\n"
+    "smlal v14.4s, v31.4h, v1.4h\n"
+    "smlal2 v9.4s, v31.8h, v1.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 29f\n"
+    "ld1 { v30.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 28f\n"
+    "ld1 { v30.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v30.b }[6], [x19]\n"
+    "b 31f\n"
+    "28:"  // Oddments: Load (2, 2): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 31f\n"
+    "ld1 { v30.b }[4], [x19]\n"
+    "b 31f\n"
+    "29:"  // Oddments: Load (2, 2): Bit 2: Unset
+    "tbz x8, #1, 30f\n"
+    "ld1 { v30.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v30.b }[2], [x19]\n"
+    "b 31f\n"
+    "30:"  // Oddments: Load (2, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 31f\n"
+    "ld1 { v30.b }[0], [x19]\n"
+    "31:"  // Oddments: Load (2, 2): Bit 2: End
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ldr x24, [x13, #0x50]\n"
+    "smlal v12.4s, v30.4h, v8.4h\n"
+    "smlal2 v17.4s, v30.8h, v8.8h\n"
+    "smlal v14.4s, v30.4h, v7.4h\n"
+    "smlal2 v9.4s, v30.8h, v7.8h\n"
+    "add x24, x24, x15\n"
+    "smlal v16.4s, v30.4h, v5.4h\n"
+    "smlal2 v10.4s, v30.8h, v5.8h\n"
+    "smlal v18.4s, v30.4h, v4.4h\n"
+    "smlal2 v26.4s, v30.8h, v4.8h\n"
+    "tbz x8, #2, 33f\n"
+    "ld1 { v29.s }[0], [x24], #0x4\n"
+    "tbz x8, #1, 32f\n"
+    "ld1 { v29.h }[2], [x24], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[6], [x24]\n"
+    "b 35f\n"
+    "32:"  // Oddments: Load (1, 0): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[4], [x24]\n"
+    "b 35f\n"
+    "33:"  // Oddments: Load (1, 0): Bit 2: Unset
+    "tbz x8, #1, 34f\n"
+    "ld1 { v29.h }[0], [x24], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[2], [x24]\n"
+    "b 35f\n"
+    "34:"  // Oddments: Load (1, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[0], [x24]\n"
+    "35:"  // Oddments: Load (1, 0): Bit 2: End
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr x23, [x13, #0x58]\n"
+    "smlal v12.4s, v29.4h, v3.4h\n"
+    "smlal2 v17.4s, v29.8h, v3.8h\n"
+    "smlal v16.4s, v29.4h, v0.4h\n"
+    "smlal2 v10.4s, v29.8h, v0.8h\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 37f\n"
+    "ld1 { v28.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 36f\n"
+    "ld1 { v28.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v28.b }[6], [x23]\n"
+    "b 39f\n"
+    "36:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 39f\n"
+    "ld1 { v28.b }[4], [x23]\n"
+    "b 39f\n"
+    "37:"  // Oddments: Load (1, 3): Bit 2: Unset
+    "tbz x8, #1, 38f\n"
+    "ld1 { v28.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v28.b }[2], [x23]\n"
+    "b 39f\n"
+    "38:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 39f\n"
+    "ld1 { v28.b }[0], [x23]\n"
+    "39:"  // Oddments: Load (1, 3): Bit 2: End
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr x22, [x13, #0x60]\n"
+    "smlal v14.4s, v28.4h, v5.4h\n"
+    "smlal2 v9.4s, v28.8h, v5.8h\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v26.4s, v28.8h, v2.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 41f\n"
+    "ld1 { v31.s }[0], [x22], #0x4\n"
+    "tbz x8, #1, 40f\n"
+    "ld1 { v31.h }[2], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v31.b }[6], [x22]\n"
+    "b 43f\n"
+    "40:"  // Oddments: Load (2, 0): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 43f\n"
+    "ld1 { v31.b }[4], [x22]\n"
+    "b 43f\n"
+    "41:"  // Oddments: Load (2, 0): Bit 2: Unset
+    "tbz x8, #1, 42f\n"
+    "ld1 { v31.h }[0], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v31.b }[2], [x22]\n"
+    "b 43f\n"
+    "42:"  // Oddments: Load (2, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 43f\n"
+    "ld1 { v31.b }[0], [x22]\n"
+    "43:"  // Oddments: Load (2, 0): Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ldr x21, [x13, #0x68]\n"
+    "smlal v12.4s, v31.4h, v6.4h\n"
+    "smlal2 v17.4s, v31.8h, v6.8h\n"
+    "smlal v16.4s, v31.4h, v3.4h\n"
+    "smlal2 v10.4s, v31.8h, v3.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 45f\n"
+    "ld1 { v30.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 44f\n"
+    "ld1 { v30.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v30.b }[6], [x21]\n"
+    "b 47f\n"
+    "44:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 47f\n"
+    "ld1 { v30.b }[4], [x21]\n"
+    "b 47f\n"
+    "45:"  // Oddments: Load (2, 3): Bit 2: Unset
+    "tbz x8, #1, 46f\n"
+    "ld1 { v30.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v30.b }[2], [x21]\n"
+    "b 47f\n"
+    "46:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 47f\n"
+    "ld1 { v30.b }[0], [x21]\n"
+    "47:"  // Oddments: Load (2, 3): Bit 2: End
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ldr x20, [x13, #0x70]\n"
+    "smlal v14.4s, v30.4h, v8.4h\n"
+    "smlal2 v9.4s, v30.8h, v8.8h\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal2 v26.4s, v30.8h, v5.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 49f\n"
+    "ld1 { v29.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 48f\n"
+    "ld1 { v29.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v29.b }[6], [x20]\n"
+    "b 51f\n"
+    "48:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 51f\n"
+    "ld1 { v29.b }[4], [x20]\n"
+    "b 51f\n"
+    "49:"  // Oddments: Load (3, 1): Bit 2: Unset
+    "tbz x8, #1, 50f\n"
+    "ld1 { v29.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v29.b }[2], [x20]\n"
+    "b 51f\n"
+    "50:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 51f\n"
+    "ld1 { v29.b }[0], [x20]\n"
+    "51:"  // Oddments: Load (3, 1): Bit 2: End
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr x19, [x13, #0x78]\n"
+    "smlal v16.4s, v29.4h, v7.4h\n"
+    "smlal2 v10.4s, v29.8h, v7.8h\n"
+    "smlal v18.4s, v29.4h, v6.4h\n"
+    "smlal2 v26.4s, v29.8h, v6.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 53f\n"
+    "ld1 { v28.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 52f\n"
+    "ld1 { v28.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v28.b }[6], [x19]\n"
+    "b 55f\n"
+    "52:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 55f\n"
+    "ld1 { v28.b }[4], [x19]\n"
+    "b 55f\n"
+    "53:"  // Oddments: Load (3, 2): Bit 2: Unset
+    "tbz x8, #1, 54f\n"
+    "ld1 { v28.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v28.b }[2], [x19]\n"
+    "b 55f\n"
+    "54:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 55f\n"
+    "ld1 { v28.b }[0], [x19]\n"
+    "55:"  // Oddments: Load (3, 2): Bit 2: End
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v16.4s, v28.4h, v8.4h\n"
+    "smlal2 v10.4s, v28.8h, v8.8h\n"
+    "smlal v18.4s, v28.4h, v7.4h\n"
+    "smlal2 v26.4s, v28.8h, v7.8h\n"
+    "tbz x8, #2, 57f\n"
+    "ld1 { v21.4s }, [x12], #0x10\n"
+    "ld1 { v24.4s }, [x11], #0x10\n"
+    "tbz x8, #1, 56f\n"
+    "ld1 { v19.d }[0], [x12], #0x8\n"
+    "ld1 { v23.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 59f\n"
+    "ld1 { v19.s }[2], [x12]\n"
+    "ld1 { v23.s }[2], [x11]\n"
+    "b 59f\n"
+    "56:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 59f\n"
+    "ld1 { v19.s }[0], [x12]\n"
+    "ld1 { v23.s }[0], [x11]\n"
+    "b 59f\n"
+    "57:"  // Oddments: Load requant params: Bit 2: Unset
+    "tbz x8, #1, 58f\n"
+    "ld1 { v21.d }[0], [x12], #0x8\n"
+    "ld1 { v24.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 59f\n"
+    "ld1 { v21.s }[2], [x12]\n"
+    "ld1 { v24.s }[2], [x11]\n"
+    "b 59f\n"
+    "58:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 59f\n"
+    "ld1 { v21.s }[0], [x12]\n"
+    "ld1 { v24.s }[0], [x11]\n"
+    "59:"  // Oddments: Load requant params: Bit 2: End
+    "sqdmulh v12.4s, v12.4s, v21.4s\n"
+    "sqdmulh v14.4s, v14.4s, v21.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v16.4s, v16.4s, v21.4s\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v29.16b, v12.16b, v24.16b\n"
+    "sqdmulh v17.4s, v17.4s, v19.4s\n"
+    "and v22.16b, v14.16b, v24.16b\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "and v21.16b, v16.16b, v24.16b\n"
+    "sqdmulh v10.4s, v10.4s, v19.4s\n"
+    "and v20.16b, v18.16b, v24.16b\n"
+    "sqdmulh v26.4s, v26.4s, v19.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v19.16b, v17.16b, v23.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v30.16b, v9.16b, v23.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v10.16b, v23.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v28.16b, v26.16b, v23.16b\n"
+    "sqadd v12.4s, v12.4s, v29.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v14.4s, v14.4s, v22.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "sqadd v16.4s, v16.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v20.4s\n"
+    "sshr v28.4s, v28.4s, #0x1f\n"
+    "srshl v12.4s, v12.4s, v24.4s\n"
+    "sqadd v17.4s, v17.4s, v19.4s\n"
+    "srshl v14.4s, v14.4s, v24.4s\n"
+    "sqadd v9.4s, v9.4s, v30.4s\n"
+    "srshl v16.4s, v16.4s, v24.4s\n"
+    "sqadd v10.4s, v10.4s, v3.4s\n"
+    "srshl v18.4s, v18.4s, v24.4s\n"
+    "sqadd v26.4s, v26.4s, v28.4s\n"
+    "srshl v17.4s, v17.4s, v23.4s\n"
+    "sqxtn v12.4h, v12.4s\n"
+    "srshl v9.4s, v9.4s, v23.4s\n"
+    "sqxtn v14.4h, v14.4s\n"
+    "srshl v10.4s, v10.4s, v23.4s\n"
+    "sqxtn v16.4h, v16.4s\n"
+    "srshl v26.4s, v26.4s, v23.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "sqxtn2 v12.8h, v17.4s\n"
+    "sqxtn2 v14.8h, v9.4s\n"
+    "sqxtn2 v16.8h, v10.4s\n"
+    "sqxtn2 v18.8h, v26.4s\n"
+    "sqadd v12.8h, v12.8h, v13.8h\n"
+    "sqadd v14.8h, v14.8h, v13.8h\n"
+    "sqadd v16.8h, v16.8h, v13.8h\n"
+    "sqadd v18.8h, v18.8h, v13.8h\n"
+    "smax v12.8h, v12.8h, v11.8h\n"
+    "smax v14.8h, v14.8h, v11.8h\n"
+    "smax v16.8h, v16.8h, v11.8h\n"
+    "smax v18.8h, v18.8h, v11.8h\n"
+    "smin v12.8h, v12.8h, v25.8h\n"
+    "smin v14.8h, v14.8h, v25.8h\n"
+    "smin v16.8h, v16.8h, v25.8h\n"
+    "smin v18.8h, v18.8h, v25.8h\n"
+    "uzp1 v12.16b, v12.16b, v12.16b\n"
+    "uzp1 v14.16b, v14.16b, v14.16b\n"
+    "uzp1 v16.16b, v16.16b, v16.16b\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "tbz x8, #2, 61f\n"
+    "st1 { v12.s }[0], [x10], #0x4\n"
+    "st1 { v14.s }[0], [x9], #0x4\n"
+    "st1 { v16.s }[0], [x28], #0x4\n"
+    "st1 { v18.s }[0], [x27], #0x4\n"
+    "tbz x8, #1, 60f\n"
+    "st1 { v12.h }[2], [x10], #0x2\n"
+    "st1 { v14.h }[2], [x9], #0x2\n"
+    "st1 { v16.h }[2], [x28], #0x2\n"
+    "st1 { v18.h }[2], [x27], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "st1 { v12.b }[6], [x10], #0x1\n"
+    "st1 { v14.b }[6], [x9], #0x1\n"
+    "st1 { v16.b }[6], [x28], #0x1\n"
+    "st1 { v18.b }[6], [x27], #0x1\n"
+    "b 63f\n"
+    "60:"  // Oddments: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 63f\n"
+    "st1 { v12.b }[4], [x10], #0x1\n"
+    "st1 { v14.b }[4], [x9], #0x1\n"
+    "st1 { v16.b }[4], [x28], #0x1\n"
+    "st1 { v18.b }[4], [x27], #0x1\n"
+    "b 63f\n"
+    "61:"  // Oddments: Bit 2: Unset
+    "tbz x8, #1, 62f\n"
+    "st1 { v12.h }[0], [x10], #0x2\n"
+    "st1 { v14.h }[0], [x9], #0x2\n"
+    "st1 { v16.h }[0], [x28], #0x2\n"
+    "st1 { v18.h }[0], [x27], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "st1 { v12.b }[2], [x10], #0x1\n"
+    "st1 { v14.b }[2], [x9], #0x1\n"
+    "st1 { v16.b }[2], [x28], #0x1\n"
+    "st1 { v18.b }[2], [x27], #0x1\n"
+    "b 63f\n"
+    "62:"  // Oddments: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 63f\n"
+    "st1 { v12.b }[0], [x10], #0x1\n"
+    "st1 { v14.b }[0], [x9], #0x1\n"
+    "st1 { v16.b }[0], [x28], #0x1\n"
+    "st1 { v18.b }[0], [x27], #0x1\n"
+    "63:"  // Oddments: Bit 2: End
+    "64:"  // End
+    :
+    : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
+    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+  );
+}
+
+}  // namespace depthwise
+}  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
new file mode 100644
index 0000000..b479dbf
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+
+#include "src/core/NEON/kernels/arm_conv/depthwise/interleaves/list.hpp"
+
+#include <cstdint>
+
+#pragma once
+
+#if defined(__aarch64__)
+
+namespace arm_conv {
+namespace depthwise {
+
+void a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);  // Kernel entry point; per its definition in generic.cpp the arguments are, in order: n_channels, inptrs, weights, bias, qp, requant_muls, requant_shifts, outptrs.
+
+class a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy wrapper that exposes the a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst_impl kernel through the DepthwiseDepthfirstStrategy interface.
+{
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
+
+  public:
+  constexpr static unsigned int kernel_rows = 3;  // Kernel window rows (the "3x3" in the class name).
+  constexpr static unsigned int kernel_cols = 3;  // Kernel window columns.
+
+  constexpr static unsigned int stride_rows = 2;  // Row stride (the "s2" in the class name).
+  constexpr static unsigned int stride_cols = 2;  // Column stride.
+
+  a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 2, 2) {}  // The CPUInfo pointer is unnamed and unused. NOTE(review): Parent arguments presumably are (output rows, output cols, kernel rows, kernel cols, stride rows, stride cols) - confirm against DepthwiseDepthfirstStrategy.
+
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }  // Always reports VLType::None.
+
+  Parent::KernelType kernel = a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;  // Kernel function pointer, initialised to the implementation declared before this class.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel pointer.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // Always 2. NOTE(review): presumably the accumulator depth in units of vector length - confirm against the base class.
+};
+
+}  // namespace depthwise
+}  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
new file mode 100644
index 0000000..49f69c4
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -0,0 +1,1395 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_gemm.hpp"
+
+#include <cstddef>
+#include <cstdint>
+
+#if defined(__aarch64__)
+
+namespace arm_conv {
+namespace depthwise {
+
+void a64_u8qa_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(
+  const unsigned int n_channels,
+  const uint8_t *const *const inptrs,
+  const uint8_t *const weights,
+  const int32_t *const bias,
+  const arm_gemm::Requantize32 &qp,
+  const int32_t *const requant_muls,
+  const int32_t *const requant_shifts,
+  uint8_t *const *const outptrs
+)
+{
+  // Argument block for the inline assembly that follows. The assembly is given the address
+  // of this object and reaches each field through an offsetof_Params_* operand; every
+  // member declared here is read by that assembly.
+  struct Params
+  {
+    long unsigned int n_channels;
+    const void *weights;
+    // Unlike its neighbours this pointer is not const: the assembly stores the advanced
+    // bias address back into this field as it consumes bias values.
+    const int32_t *bias;
+    const arm_gemm::Requantize32 *requant;
+    const int32_t *const requant_muls;
+    const int32_t *const requant_shifts;
+    uint8_t *const *const outptrs;
+    // Input pointers, reordered by the constructor; the assembly indexes this array at
+    // byte offsets 0x0 through 0xc0.
+    const uint8_t *inptrs[25];
+
+    // Stores the kernel arguments (keeping the address of `qp`) and fills `inptrs` from
+    // `inptrs_raw`, which must therefore provide at least 25 readable entries.
+    Params(
+      long unsigned int channel_count,
+      const uint8_t *const *inptrs_raw,
+      const void *const weights_in,
+      const int32_t *const bias_in,
+      const arm_gemm::Requantize32 &qp,
+      const int32_t *const muls_in,
+      const int32_t *const shifts_in,
+      uint8_t *const *outptrs_in
+    ) : n_channels(channel_count),
+        weights(weights_in),
+        bias(bias_in),
+        requant(&qp),
+        requant_muls(muls_in),
+        requant_shifts(shifts_in),
+        outptrs(outptrs_in)
+    {
+      // source_slot[i] is the index within inptrs_raw of the pointer placed in inptrs[i].
+      static constexpr unsigned int source_slot[25] = {
+        12,  0,  1,  3,  4,
+         5,  6,  2,  8,  9,
+         7, 15, 10, 16, 11,
+        18, 13, 19, 20, 14,
+        21, 17, 23, 22, 24,
+      };
+
+      for (unsigned int slot = 0; slot < 25; slot++)
+      {
+        inptrs[slot] = inptrs_raw[source_slot[slot]];
+      }
+    }
+  };
+
+  const Params params(n_channels, inptrs, weights, bias, qp,
+                      requant_muls, requant_shifts, outptrs);
+
+  __asm__ __volatile__(
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "add x22, x19, %[offsetof_Requantize32_c_offset]\n"
+    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "ld1r { v16.16b }, [x23]\n"
+    "ld1r { v12.8h }, [x22]\n"
+    "lsr x16, x8, #0x3\n"
+    "mov x15, #0x0\n"
+    "ld1r { v14.8h }, [x20]\n"
+    "ld1r { v21.8h }, [x19]\n"
+    "mov x14, #0x0\n"
+    "add x13, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x12, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x10, x9, [x21, #0x0]\n"
+    "ldp x28, x27, [x21, #0x10]\n"
+    "cbz x16, 3f\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "subs x16, x16, #0x1\n"
+    "mov v13.16b, v15.16b\n"
+    "ldr q18, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v17.16b, v18.16b\n"
+    "mov v11.16b, v15.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v10.16b, v18.16b\n"
+    "mov v23.16b, v15.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v9.16b, v18.16b\n"
+    "usubl v0.8h, v0.8b, v16.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v1.8h, v1.8b, v16.8b\n"
+    "usubl v2.8h, v2.8b, v16.8b\n"
+    "ldp x26, x25, [x13, #0x0]\n"
+    "ldp x24, x23, [x13, #0x10]\n"
+    "usubl v3.8h, v3.8b, v16.8b\n"
+    "usubl v4.8h, v4.8b, v16.8b\n"
+    "ldp x22, x21, [x13, #0x20]\n"
+    "ldp x20, x19, [x13, #0x30]\n"
+    "usubl v5.8h, v5.8b, v16.8b\n"
+    "usubl v6.8h, v6.8b, v16.8b\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
+    "usubl v7.8h, v7.8b, v16.8b\n"
+    "usubl v8.8h, v8.8b, v16.8b\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "beq 2f\n"
+    "1:"  // Loop
+    "smlal v15.4s, v31.4h, v8.4h\n"
+    "smlal2 v18.4s, v31.8h, v8.8h\n"
+    "ldr x24, [x13, #0x40]\n"
+    "ldr x23, [x13, #0x48]\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal2 v17.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x13, #0x50]\n"
+    "ldr x19, [x13, #0x58]\n"
+    "smlal v15.4s, v30.4h, v0.4h\n"
+    "smlal2 v18.4s, v30.8h, v0.8h\n"
+    "ldr x22, [x13, #0x78]\n"
+    "ldr x20, [x13, #0x60]\n"
+    "smlal v13.4s, v28.4h, v1.4h\n"
+    "smlal2 v17.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v15.4s, v29.4h, v1.4h\n"
+    "smlal2 v18.4s, v29.8h, v1.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v13.4s, v27.4h, v2.4h\n"
+    "smlal2 v17.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v15.4s, v26.4h, v3.4h\n"
+    "smlal2 v18.4s, v26.8h, v3.8h\n"
+    "ldr d26, [x19, x15]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v13.4s, v24.4h, v0.4h\n"
+    "smlal2 v17.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x13, #0x80]\n"
+    "ldr x19, [x13, #0x68]\n"
+    "smlal v15.4s, v25.4h, v4.4h\n"
+    "smlal2 v18.4s, v25.8h, v4.8h\n"
+    "ldr d25, [x20, x15]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v13.4s, v29.4h, v4.4h\n"
+    "smlal2 v17.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x13, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v18.4s, v24.8h, v2.8h\n"
+    "ldr x19, [x13, #0x70]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v13.4s, v28.4h, v5.4h\n"
+    "smlal2 v17.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v11.4s, v31.4h, v2.4h\n"
+    "smlal2 v10.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x13, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
+    "smlal2 v18.4s, v27.8h, v5.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "ldr x23, [x13, #0x90]\n"
+    "smlal v13.4s, v27.4h, v3.4h\n"
+    "smlal2 v17.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v11.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x13, #0xa8]\n"
+    "ldr x19, [x13, #0xa0]\n"
+    "smlal2 v10.4s, v26.8h, v3.8h\n"
+    "smlal2 v9.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v11.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x13, #0xb0]\n"
+    "ldr x20, [x13, #0xb8]\n"
+    "smlal2 v10.4s, v25.8h, v0.8h\n"
+    "smlal2 v9.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x13, #0xc0]\n"
+    "ldr q22, [x12, #0x0]\n"
+    "smlal2 v18.4s, v25.8h, v6.8h\n"
+    "smlal v11.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal2 v10.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v9.4s, v28.8h, v1.8h\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q31, [x11, #0x0]\n"
+    "ldr q19, [x12, #0x10]\n"
+    "smlal2 v18.4s, v24.8h, v7.8h\n"
+    "smlal v11.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v22.4s\n"
+    "ldr q30, [x11, #0x10]\n"
+    "smlal2 v10.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v9.4s, v26.8h, v5.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v9.4s, v29.8h, v2.8h\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v11.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v4.16b, v15.16b, v31.16b\n"
+    "add x17, x17, #0x48\n"
+    "smlal v13.4s, v28.4h, v7.4h\n"
+    "smlal2 v17.4s, v28.8h, v7.8h\n"
+    "sqdmulh v18.4s, v18.4s, v19.4s\n"
+    "subs x16, x16, #0x1\n"
+    "smlal2 v10.4s, v25.8h, v6.8h\n"
+    "ldr d25, [x20, x15]\n"
+    "smlal2 v9.4s, v24.8h, v3.8h\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v11.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "add x12, x12, #0x20\n"
+    "smlal v13.4s, v29.4h, v8.4h\n"
+    "smlal2 v17.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal2 v10.4s, v27.8h, v7.8h\n"
+    "smlal2 v9.4s, v26.8h, v7.8h\n"
+    "sqdmulh v13.4s, v13.4s, v22.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v11.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v1.16b, v13.16b, v31.16b\n"
+    "add x11, x11, #0x20\n"
+    "smlal2 v10.4s, v24.8h, v5.8h\n"
+    "smlal2 v9.4s, v25.8h, v6.8h\n"
+    "sqdmulh v17.4s, v17.4s, v19.4s\n"
+    "smlal v11.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v11.4s, v11.4s, v22.4s\n"
+    "smlal2 v10.4s, v25.8h, v8.8h\n"
+    "smlal2 v9.4s, v29.8h, v8.8h\n"
+    "sqdmulh v23.4s, v23.4s, v22.4s\n"
+    "and v22.16b, v11.16b, v31.16b\n"
+    "sqdmulh v10.4s, v10.4s, v19.4s\n"
+    "and v20.16b, v23.16b, v31.16b\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "and v19.16b, v18.16b, v30.16b\n"
+    "sshr v1.4s, v1.4s, #0x1f\n"
+    "and v27.16b, v17.16b, v30.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v25.16b, v10.16b, v30.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v0.16b, v9.16b, v30.16b\n"
+    "sqadd v15.4s, v15.4s, v4.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v13.4s, v13.4s, v1.4s\n"
+    "sshr v27.4s, v27.4s, #0x1f\n"
+    "sqadd v11.4s, v11.4s, v22.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v0.4s, v0.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v31.4s\n"
+    "sqadd v18.4s, v18.4s, v19.4s\n"
+    "srshl v13.4s, v13.4s, v31.4s\n"
+    "sqadd v17.4s, v17.4s, v27.4s\n"
+    "srshl v11.4s, v11.4s, v31.4s\n"
+    "sqadd v10.4s, v10.4s, v25.4s\n"
+    "srshl v23.4s, v23.4s, v31.4s\n"
+    "sqadd v9.4s, v9.4s, v0.4s\n"
+    "srshl v18.4s, v18.4s, v30.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v17.4s, v17.4s, v30.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v10.4s, v10.4s, v30.4s\n"
+    "sqxtn v11.4h, v11.4s\n"
+    "srshl v9.4s, v9.4s, v30.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v18.4s\n"
+    "sqxtn2 v13.8h, v17.4s\n"
+    "sqxtn2 v11.8h, v10.4s\n"
+    "sqxtn2 v23.8h, v9.4s\n"
+    "sqadd v15.8h, v15.8h, v12.8h\n"
+    "sqadd v13.8h, v13.8h, v12.8h\n"
+    "sqadd v11.8h, v11.8h, v12.8h\n"
+    "sqadd v23.8h, v23.8h, v12.8h\n"
+    "smax v15.8h, v15.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v14.8h\n"
+    "smax v11.8h, v11.8h, v14.8h\n"
+    "smax v23.8h, v23.8h, v14.8h\n"
+    "smin v15.8h, v15.8h, v21.8h\n"
+    "smin v13.8h, v13.8h, v21.8h\n"
+    "smin v11.8h, v11.8h, v21.8h\n"
+    "smin v23.8h, v23.8h, v21.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v11.16b, v11.16b, v11.16b\n"
+    "str d13, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d11, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "add x14, x14, #0x8\n"
+    "ldr q18, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v13.16b, v15.16b\n"
+    "mov v17.16b, v18.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v11.16b, v15.16b\n"
+    "mov v10.16b, v18.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v23.16b, v15.16b\n"
+    "mov v9.16b, v18.16b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "usubl v0.8h, v0.8b, v16.8b\n"
+    "usubl v1.8h, v1.8b, v16.8b\n"
+    "ldp x26, x25, [x13, #0x0]\n"
+    "ldp x24, x23, [x13, #0x10]\n"
+    "usubl v2.8h, v2.8b, v16.8b\n"
+    "usubl v3.8h, v3.8b, v16.8b\n"
+    "ldp x22, x21, [x13, #0x20]\n"
+    "ldp x20, x19, [x13, #0x30]\n"
+    "usubl v4.8h, v4.8b, v16.8b\n"
+    "usubl v5.8h, v5.8b, v16.8b\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
+    "usubl v6.8h, v6.8b, v16.8b\n"
+    "usubl v7.8h, v7.8b, v16.8b\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v8.8h, v8.8b, v16.8b\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "bgt 1b\n"
+    "2:"  // Tail
+    "smlal v15.4s, v31.4h, v8.4h\n"
+    "smlal2 v18.4s, v31.8h, v8.8h\n"
+    "ldr x24, [x13, #0x40]\n"
+    "ldr x23, [x13, #0x48]\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal2 v17.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x13, #0x50]\n"
+    "ldr x19, [x13, #0x58]\n"
+    "smlal v15.4s, v30.4h, v0.4h\n"
+    "smlal2 v18.4s, v30.8h, v0.8h\n"
+    "ldr x22, [x13, #0x78]\n"
+    "ldr x20, [x13, #0x60]\n"
+    "smlal v13.4s, v28.4h, v1.4h\n"
+    "smlal2 v17.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v15.4s, v29.4h, v1.4h\n"
+    "smlal2 v18.4s, v29.8h, v1.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v13.4s, v27.4h, v2.4h\n"
+    "smlal2 v17.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v15.4s, v26.4h, v3.4h\n"
+    "smlal2 v18.4s, v26.8h, v3.8h\n"
+    "ldr d26, [x19, x15]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v13.4s, v24.4h, v0.4h\n"
+    "smlal2 v17.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x13, #0x80]\n"
+    "ldr x19, [x13, #0x68]\n"
+    "smlal v15.4s, v25.4h, v4.4h\n"
+    "smlal2 v18.4s, v25.8h, v4.8h\n"
+    "ldr d25, [x20, x15]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v13.4s, v29.4h, v4.4h\n"
+    "smlal2 v17.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x13, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v18.4s, v24.8h, v2.8h\n"
+    "ldr x19, [x13, #0x70]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v13.4s, v28.4h, v5.4h\n"
+    "smlal2 v17.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal v11.4s, v31.4h, v2.4h\n"
+    "smlal2 v10.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x13, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
+    "smlal2 v18.4s, v27.8h, v5.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "ldr x23, [x13, #0x90]\n"
+    "smlal v13.4s, v27.4h, v3.4h\n"
+    "smlal2 v17.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v11.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x13, #0xa8]\n"
+    "ldr x19, [x13, #0xa0]\n"
+    "smlal2 v10.4s, v26.8h, v3.8h\n"
+    "smlal2 v9.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v11.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x13, #0xb0]\n"
+    "ldr x20, [x13, #0xb8]\n"
+    "smlal2 v10.4s, v25.8h, v0.8h\n"
+    "smlal2 v9.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x13, #0xc0]\n"
+    "ldr q22, [x12, #0x0]\n"
+    "smlal2 v18.4s, v25.8h, v6.8h\n"
+    "smlal v11.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal2 v10.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v9.4s, v28.8h, v1.8h\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q31, [x11, #0x0]\n"
+    "ldr q19, [x12, #0x10]\n"
+    "smlal2 v18.4s, v24.8h, v7.8h\n"
+    "smlal v11.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v22.4s\n"
+    "ldr q30, [x11, #0x10]\n"
+    "smlal2 v10.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v9.4s, v26.8h, v5.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v9.4s, v29.8h, v2.8h\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v11.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v4.16b, v15.16b, v31.16b\n"
+    "tst x8, #0x7\n"
+    "smlal v13.4s, v28.4h, v7.4h\n"
+    "smlal2 v17.4s, v28.8h, v7.8h\n"
+    "sqdmulh v18.4s, v18.4s, v19.4s\n"
+    "add x12, x12, #0x20\n"
+    "smlal2 v10.4s, v25.8h, v6.8h\n"
+    "ldr d25, [x20, x15]\n"
+    "smlal2 v9.4s, v24.8h, v3.8h\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v11.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "add x11, x11, #0x20\n"
+    "smlal v13.4s, v29.4h, v8.4h\n"
+    "smlal2 v17.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal2 v10.4s, v27.8h, v7.8h\n"
+    "smlal2 v9.4s, v26.8h, v7.8h\n"
+    "sqdmulh v13.4s, v13.4s, v22.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v11.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v1.16b, v13.16b, v31.16b\n"
+    "smlal2 v10.4s, v24.8h, v5.8h\n"
+    "smlal2 v9.4s, v25.8h, v6.8h\n"
+    "sqdmulh v17.4s, v17.4s, v19.4s\n"
+    "smlal v11.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v11.4s, v11.4s, v22.4s\n"
+    "smlal2 v10.4s, v25.8h, v8.8h\n"
+    "smlal2 v9.4s, v29.8h, v8.8h\n"
+    "sqdmulh v23.4s, v23.4s, v22.4s\n"
+    "and v22.16b, v11.16b, v31.16b\n"
+    "sqdmulh v10.4s, v10.4s, v19.4s\n"
+    "and v20.16b, v23.16b, v31.16b\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "and v19.16b, v18.16b, v30.16b\n"
+    "sshr v1.4s, v1.4s, #0x1f\n"
+    "and v27.16b, v17.16b, v30.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v25.16b, v10.16b, v30.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v0.16b, v9.16b, v30.16b\n"
+    "sqadd v15.4s, v15.4s, v4.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v13.4s, v13.4s, v1.4s\n"
+    "sshr v27.4s, v27.4s, #0x1f\n"
+    "sqadd v11.4s, v11.4s, v22.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v0.4s, v0.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v31.4s\n"
+    "sqadd v18.4s, v18.4s, v19.4s\n"
+    "srshl v13.4s, v13.4s, v31.4s\n"
+    "sqadd v17.4s, v17.4s, v27.4s\n"
+    "srshl v11.4s, v11.4s, v31.4s\n"
+    "sqadd v10.4s, v10.4s, v25.4s\n"
+    "srshl v23.4s, v23.4s, v31.4s\n"
+    "sqadd v9.4s, v9.4s, v0.4s\n"
+    "srshl v18.4s, v18.4s, v30.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v17.4s, v17.4s, v30.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v10.4s, v10.4s, v30.4s\n"
+    "sqxtn v11.4h, v11.4s\n"
+    "srshl v9.4s, v9.4s, v30.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v18.4s\n"
+    "sqxtn2 v13.8h, v17.4s\n"
+    "sqxtn2 v11.8h, v10.4s\n"
+    "sqxtn2 v23.8h, v9.4s\n"
+    "sqadd v15.8h, v15.8h, v12.8h\n"
+    "sqadd v13.8h, v13.8h, v12.8h\n"
+    "sqadd v11.8h, v11.8h, v12.8h\n"
+    "sqadd v23.8h, v23.8h, v12.8h\n"
+    "smax v15.8h, v15.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v14.8h\n"
+    "smax v11.8h, v11.8h, v14.8h\n"
+    "smax v23.8h, v23.8h, v14.8h\n"
+    "smin v15.8h, v15.8h, v21.8h\n"
+    "smin v13.8h, v13.8h, v21.8h\n"
+    "smin v11.8h, v11.8h, v21.8h\n"
+    "smin v23.8h, v23.8h, v21.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v11.16b, v11.16b, v11.16b\n"
+    "str d13, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d11, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
+    "beq 88f\n"
+    "add x17, x17, #0x48\n"
+    "3:"  // Oddments
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "tbz x8, #2, 5f\n"
+    "ld1 { v15.4s }, [x19], #0x10\n"
+    "tbz x8, #1, 4f\n"
+    "ld1 { v18.d }[0], [x19], #0x8\n"
+    "tbz x8, #0, 7f\n"
+    "ld1 { v18.s }[2], [x19]\n"
+    "b 7f\n"
+    "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 7f\n"
+    "ld1 { v18.s }[0], [x19]\n"
+    "b 7f\n"
+    "5:"  // Oddments: Load bias: Bit 2: Unset
+    "tbz x8, #1, 6f\n"
+    "ld1 { v15.d }[0], [x19], #0x8\n"
+    "tbz x8, #0, 7f\n"
+    "ld1 { v15.s }[2], [x19]\n"
+    "b 7f\n"
+    "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 7f\n"
+    "ld1 { v15.s }[0], [x19]\n"
+    "7:"  // Oddments: Load bias: Bit 2: End
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v13.16b, v15.16b\n"
+    "mov v17.16b, v18.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v11.16b, v15.16b\n"
+    "mov v10.16b, v18.16b\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v23.16b, v15.16b\n"
+    "mov v9.16b, v18.16b\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "usubl v0.8h, v0.8b, v16.8b\n"
+    "usubl v1.8h, v1.8b, v16.8b\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x26, x25, [x13, #0x0]\n"
+    "usubl v2.8h, v2.8b, v16.8b\n"
+    "usubl v3.8h, v3.8b, v16.8b\n"
+    "ldp x24, x23, [x13, #0x10]\n"
+    "ldp x22, x21, [x13, #0x20]\n"
+    "usubl v4.8h, v4.8b, v16.8b\n"
+    "usubl v5.8h, v5.8b, v16.8b\n"
+    "ldp x20, x19, [x13, #0x30]\n"
+    "usubl v6.8h, v6.8b, v16.8b\n"
+    "usubl v7.8h, v7.8b, v16.8b\n"
+    "usubl v8.8h, v8.8b, v16.8b\n"
+    "add x26, x26, x15\n"
+    "add x25, x25, x15\n"
+    "add x24, x24, x15\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 9f\n"
+    "ld1 { v31.s }[0], [x26], #0x4\n"
+    "ld1 { v30.s }[0], [x25], #0x4\n"
+    "ld1 { v29.s }[0], [x24], #0x4\n"
+    "ld1 { v28.s }[0], [x23], #0x4\n"
+    "ld1 { v27.s }[0], [x22], #0x4\n"
+    "ld1 { v26.s }[0], [x21], #0x4\n"
+    "ld1 { v25.s }[0], [x20], #0x4\n"
+    "ld1 { v24.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 8f\n"
+    "ld1 { v31.h }[2], [x26], #0x2\n"
+    "ld1 { v30.h }[2], [x25], #0x2\n"
+    "ld1 { v29.h }[2], [x24], #0x2\n"
+    "ld1 { v28.h }[2], [x23], #0x2\n"
+    "ld1 { v27.h }[2], [x22], #0x2\n"
+    "ld1 { v26.h }[2], [x21], #0x2\n"
+    "ld1 { v25.h }[2], [x20], #0x2\n"
+    "ld1 { v24.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[6], [x26]\n"
+    "ld1 { v30.b }[6], [x25]\n"
+    "ld1 { v29.b }[6], [x24]\n"
+    "ld1 { v28.b }[6], [x23]\n"
+    "ld1 { v27.b }[6], [x22]\n"
+    "ld1 { v26.b }[6], [x21]\n"
+    "ld1 { v25.b }[6], [x20]\n"
+    "ld1 { v24.b }[6], [x19]\n"
+    "b 11f\n"
+    "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[4], [x26]\n"
+    "ld1 { v30.b }[4], [x25]\n"
+    "ld1 { v29.b }[4], [x24]\n"
+    "ld1 { v28.b }[4], [x23]\n"
+    "ld1 { v27.b }[4], [x22]\n"
+    "ld1 { v26.b }[4], [x21]\n"
+    "ld1 { v25.b }[4], [x20]\n"
+    "ld1 { v24.b }[4], [x19]\n"
+    "b 11f\n"
+    "9:"  // Oddments: Initial loads: Bit 2: Unset
+    "tbz x8, #1, 10f\n"
+    "ld1 { v31.h }[0], [x26], #0x2\n"
+    "ld1 { v30.h }[0], [x25], #0x2\n"
+    "ld1 { v29.h }[0], [x24], #0x2\n"
+    "ld1 { v28.h }[0], [x23], #0x2\n"
+    "ld1 { v27.h }[0], [x22], #0x2\n"
+    "ld1 { v26.h }[0], [x21], #0x2\n"
+    "ld1 { v25.h }[0], [x20], #0x2\n"
+    "ld1 { v24.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[2], [x26]\n"
+    "ld1 { v30.b }[2], [x25]\n"
+    "ld1 { v29.b }[2], [x24]\n"
+    "ld1 { v28.b }[2], [x23]\n"
+    "ld1 { v27.b }[2], [x22]\n"
+    "ld1 { v26.b }[2], [x21]\n"
+    "ld1 { v25.b }[2], [x20]\n"
+    "ld1 { v24.b }[2], [x19]\n"
+    "b 11f\n"
+    "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 11f\n"
+    "ld1 { v31.b }[0], [x26]\n"
+    "ld1 { v30.b }[0], [x25]\n"
+    "ld1 { v29.b }[0], [x24]\n"
+    "ld1 { v28.b }[0], [x23]\n"
+    "ld1 { v27.b }[0], [x22]\n"
+    "ld1 { v26.b }[0], [x21]\n"
+    "ld1 { v25.b }[0], [x20]\n"
+    "ld1 { v24.b }[0], [x19]\n"
+    "11:"  // Oddments: Initial loads: Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v15.4s, v31.4h, v8.4h\n"
+    "smlal2 v18.4s, v31.8h, v8.8h\n"
+    "ldr x24, [x13, #0x40]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v15.4s, v30.4h, v0.4h\n"
+    "smlal2 v18.4s, v30.8h, v0.8h\n"
+    "add x24, x24, x15\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal2 v17.4s, v31.8h, v6.8h\n"
+    "smlal v15.4s, v29.4h, v1.4h\n"
+    "smlal2 v18.4s, v29.8h, v1.8h\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v13.4s, v28.4h, v1.4h\n"
+    "smlal2 v17.4s, v28.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v3.4h\n"
+    "smlal2 v18.4s, v26.8h, v3.8h\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v13.4s, v27.4h, v2.4h\n"
+    "smlal2 v17.4s, v27.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v4.4h\n"
+    "smlal2 v18.4s, v25.8h, v4.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v11.4s, v31.4h, v2.4h\n"
+    "smlal2 v10.4s, v31.8h, v2.8h\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal2 v9.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v18.4s, v24.8h, v2.8h\n"
+    "smlal v13.4s, v24.4h, v0.4h\n"
+    "smlal2 v17.4s, v24.8h, v0.8h\n"
+    "tbz x8, #2, 13f\n"
+    "ld1 { v29.s }[0], [x24], #0x4\n"
+    "tbz x8, #1, 12f\n"
+    "ld1 { v29.h }[2], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[6], [x24]\n"
+    "b 15f\n"
+    "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[4], [x24]\n"
+    "b 15f\n"
+    "13:"  // Oddments: Load (1, 3): Bit 2: Unset
+    "tbz x8, #1, 14f\n"
+    "ld1 { v29.h }[0], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[2], [x24]\n"
+    "b 15f\n"
+    "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[0], [x24]\n"
+    "15:"  // Oddments: Load (1, 3): Bit 2: End
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr x23, [x13, #0x48]\n"
+    "smlal v13.4s, v29.4h, v4.4h\n"
+    "smlal2 v17.4s, v29.8h, v4.8h\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 17f\n"
+    "ld1 { v28.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 16f\n"
+    "ld1 { v28.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[6], [x23]\n"
+    "b 19f\n"
+    "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[4], [x23]\n"
+    "b 19f\n"
+    "17:"  // Oddments: Load (1, 4): Bit 2: Unset
+    "tbz x8, #1, 18f\n"
+    "ld1 { v28.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[2], [x23]\n"
+    "b 19f\n"
+    "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[0], [x23]\n"
+    "19:"  // Oddments: Load (1, 4): Bit 2: End
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr x21, [x13, #0x50]\n"
+    "smlal v13.4s, v28.4h, v5.4h\n"
+    "smlal2 v17.4s, v28.8h, v5.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 21f\n"
+    "ld1 { v27.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 20f\n"
+    "ld1 { v27.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[6], [x21]\n"
+    "b 23f\n"
+    "20:"  // Oddments: Load (1, 2): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[4], [x21]\n"
+    "b 23f\n"
+    "21:"  // Oddments: Load (1, 2): Bit 2: Unset
+    "tbz x8, #1, 22f\n"
+    "ld1 { v27.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[2], [x21]\n"
+    "b 23f\n"
+    "22:"  // Oddments: Load (1, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[0], [x21]\n"
+    "23:"  // Oddments: Load (1, 2): Bit 2: End
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ldr x19, [x13, #0x58]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
+    "smlal2 v18.4s, v27.8h, v5.8h\n"
+    "smlal v13.4s, v27.4h, v3.4h\n"
+    "smlal2 v17.4s, v27.8h, v3.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 25f\n"
+    "ld1 { v26.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 24f\n"
+    "ld1 { v26.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[6], [x19]\n"
+    "b 27f\n"
+    "24:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[4], [x19]\n"
+    "b 27f\n"
+    "25:"  // Oddments: Load (3, 0): Bit 2: Unset
+    "tbz x8, #1, 26f\n"
+    "ld1 { v26.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[2], [x19]\n"
+    "b 27f\n"
+    "26:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 27f\n"
+    "ld1 { v26.b }[0], [x19]\n"
+    "27:"  // Oddments: Load (3, 0): Bit 2: End
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "ldr x20, [x13, #0x60]\n"
+    "smlal v11.4s, v26.4h, v3.4h\n"
+    "smlal2 v10.4s, v26.8h, v3.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 29f\n"
+    "ld1 { v25.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 28f\n"
+    "ld1 { v25.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[6], [x20]\n"
+    "b 31f\n"
+    "28:"  // Oddments: Load (2, 0): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[4], [x20]\n"
+    "b 31f\n"
+    "29:"  // Oddments: Load (2, 0): Bit 2: Unset
+    "tbz x8, #1, 30f\n"
+    "ld1 { v25.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[2], [x20]\n"
+    "b 31f\n"
+    "30:"  // Oddments: Load (2, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[0], [x20]\n"
+    "31:"  // Oddments: Load (2, 0): Bit 2: End
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ldr x19, [x13, #0x68]\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "smlal2 v18.4s, v25.8h, v6.8h\n"
+    "smlal v11.4s, v25.4h, v0.4h\n"
+    "smlal2 v10.4s, v25.8h, v0.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 33f\n"
+    "ld1 { v29.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 32f\n"
+    "ld1 { v29.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[6], [x19]\n"
+    "b 35f\n"
+    "32:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[4], [x19]\n"
+    "b 35f\n"
+    "33:"  // Oddments: Load (3, 1): Bit 2: Unset
+    "tbz x8, #1, 34f\n"
+    "ld1 { v29.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[2], [x19]\n"
+    "b 35f\n"
+    "34:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[0], [x19]\n"
+    "35:"  // Oddments: Load (3, 1): Bit 2: End
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr x19, [x13, #0x70]\n"
+    "smlal v11.4s, v29.4h, v4.4h\n"
+    "smlal2 v10.4s, v29.8h, v4.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 37f\n"
+    "ld1 { v24.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 36f\n"
+    "ld1 { v24.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[6], [x19]\n"
+    "b 39f\n"
+    "36:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[4], [x19]\n"
+    "b 39f\n"
+    "37:"  // Oddments: Load (2, 1): Bit 2: Unset
+    "tbz x8, #1, 38f\n"
+    "ld1 { v24.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[2], [x19]\n"
+    "b 39f\n"
+    "38:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[0], [x19]\n"
+    "39:"  // Oddments: Load (2, 1): Bit 2: End
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "ldr x22, [x13, #0x78]\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "smlal2 v18.4s, v24.8h, v7.8h\n"
+    "smlal v11.4s, v24.4h, v1.4h\n"
+    "smlal2 v10.4s, v24.8h, v1.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 41f\n"
+    "ld1 { v27.s }[0], [x22], #0x4\n"
+    "tbz x8, #1, 40f\n"
+    "ld1 { v27.h }[2], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[6], [x22]\n"
+    "b 43f\n"
+    "40:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[4], [x22]\n"
+    "b 43f\n"
+    "41:"  // Oddments: Load (3, 3): Bit 2: Unset
+    "tbz x8, #1, 42f\n"
+    "ld1 { v27.h }[0], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[2], [x22]\n"
+    "b 43f\n"
+    "42:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[0], [x22]\n"
+    "43:"  // Oddments: Load (3, 3): Bit 2: End
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ldr x21, [x13, #0x80]\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal2 v9.4s, v27.8h, v4.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 45f\n"
+    "ld1 { v28.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 44f\n"
+    "ld1 { v28.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[6], [x21]\n"
+    "b 47f\n"
+    "44:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[4], [x21]\n"
+    "b 47f\n"
+    "45:"  // Oddments: Load (2, 3): Bit 2: Unset
+    "tbz x8, #1, 46f\n"
+    "ld1 { v28.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[2], [x21]\n"
+    "b 47f\n"
+    "46:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[0], [x21]\n"
+    "47:"  // Oddments: Load (2, 3): Bit 2: End
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr x20, [x13, #0x88]\n"
+    "smlal v13.4s, v28.4h, v7.4h\n"
+    "smlal2 v17.4s, v28.8h, v7.8h\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal2 v9.4s, v28.8h, v1.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 49f\n"
+    "ld1 { v26.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 48f\n"
+    "ld1 { v26.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[6], [x20]\n"
+    "b 51f\n"
+    "48:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[4], [x20]\n"
+    "b 51f\n"
+    "49:"  // Oddments: Load (3, 4): Bit 2: Unset
+    "tbz x8, #1, 50f\n"
+    "ld1 { v26.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[2], [x20]\n"
+    "b 51f\n"
+    "50:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[0], [x20]\n"
+    "51:"  // Oddments: Load (3, 4): Bit 2: End
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "ldr x23, [x13, #0x90]\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal2 v9.4s, v26.8h, v5.8h\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 53f\n"
+    "ld1 { v25.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 52f\n"
+    "ld1 { v25.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[6], [x23]\n"
+    "b 55f\n"
+    "52:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[4], [x23]\n"
+    "b 55f\n"
+    "53:"  // Oddments: Load (4, 0): Bit 2: Unset
+    "tbz x8, #1, 54f\n"
+    "ld1 { v25.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[2], [x23]\n"
+    "b 55f\n"
+    "54:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[0], [x23]\n"
+    "55:"  // Oddments: Load (4, 0): Bit 2: End
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ldr x24, [x13, #0x98]\n"
+    "smlal v11.4s, v25.4h, v6.4h\n"
+    "smlal2 v10.4s, v25.8h, v6.8h\n"
+    "add x24, x24, x15\n"
+    "tbz x8, #2, 57f\n"
+    "ld1 { v29.s }[0], [x24], #0x4\n"
+    "tbz x8, #1, 56f\n"
+    "ld1 { v29.h }[2], [x24], #0x2\n"
+    "tbz x8, #0, 59f\n"
+    "ld1 { v29.b }[6], [x24]\n"
+    "b 59f\n"
+    "56:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 59f\n"
+    "ld1 { v29.b }[4], [x24]\n"
+    "b 59f\n"
+    "57:"  // Oddments: Load (2, 4): Bit 2: Unset
+    "tbz x8, #1, 58f\n"
+    "ld1 { v29.h }[0], [x24], #0x2\n"
+    "tbz x8, #0, 59f\n"
+    "ld1 { v29.b }[2], [x24]\n"
+    "b 59f\n"
+    "58:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 59f\n"
+    "ld1 { v29.b }[0], [x24]\n"
+    "59:"  // Oddments: Load (2, 4): Bit 2: End
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr x19, [x13, #0xa0]\n"
+    "smlal v13.4s, v29.4h, v8.4h\n"
+    "smlal2 v17.4s, v29.8h, v8.8h\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "smlal2 v9.4s, v29.8h, v2.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 61f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 60f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[6], [x19]\n"
+    "b 63f\n"
+    "60:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[4], [x19]\n"
+    "b 63f\n"
+    "61:"  // Oddments: Load (4, 1): Bit 2: Unset
+    "tbz x8, #1, 62f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[2], [x19]\n"
+    "b 63f\n"
+    "62:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[0], [x19]\n"
+    "63:"  // Oddments: Load (4, 1): Bit 2: End
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ldr x22, [x13, #0xa8]\n"
+    "smlal v11.4s, v27.4h, v7.4h\n"
+    "smlal2 v10.4s, v27.8h, v7.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 65f\n"
+    "ld1 { v24.s }[0], [x22], #0x4\n"
+    "tbz x8, #1, 64f\n"
+    "ld1 { v24.h }[2], [x22], #0x2\n"
+    "tbz x8, #0, 67f\n"
+    "ld1 { v24.b }[6], [x22]\n"
+    "b 67f\n"
+    "64:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 67f\n"
+    "ld1 { v24.b }[4], [x22]\n"
+    "b 67f\n"
+    "65:"  // Oddments: Load (3, 2): Bit 2: Unset
+    "tbz x8, #1, 66f\n"
+    "ld1 { v24.h }[0], [x22], #0x2\n"
+    "tbz x8, #0, 67f\n"
+    "ld1 { v24.b }[2], [x22]\n"
+    "b 67f\n"
+    "66:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 67f\n"
+    "ld1 { v24.b }[0], [x22]\n"
+    "67:"  // Oddments: Load (3, 2): Bit 2: End
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "ldr x21, [x13, #0xb0]\n"
+    "smlal v11.4s, v24.4h, v5.4h\n"
+    "smlal2 v10.4s, v24.8h, v5.8h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "smlal2 v9.4s, v24.8h, v3.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 69f\n"
+    "ld1 { v26.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 68f\n"
+    "ld1 { v26.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 71f\n"
+    "ld1 { v26.b }[6], [x21]\n"
+    "b 71f\n"
+    "68:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 71f\n"
+    "ld1 { v26.b }[4], [x21]\n"
+    "b 71f\n"
+    "69:"  // Oddments: Load (4, 3): Bit 2: Unset
+    "tbz x8, #1, 70f\n"
+    "ld1 { v26.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 71f\n"
+    "ld1 { v26.b }[2], [x21]\n"
+    "b 71f\n"
+    "70:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 71f\n"
+    "ld1 { v26.b }[0], [x21]\n"
+    "71:"  // Oddments: Load (4, 3): Bit 2: End
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "ldr x20, [x13, #0xb8]\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "smlal2 v9.4s, v26.8h, v7.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 73f\n"
+    "ld1 { v25.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 72f\n"
+    "ld1 { v25.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 75f\n"
+    "ld1 { v25.b }[6], [x20]\n"
+    "b 75f\n"
+    "72:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 75f\n"
+    "ld1 { v25.b }[4], [x20]\n"
+    "b 75f\n"
+    "73:"  // Oddments: Load (4, 2): Bit 2: Unset
+    "tbz x8, #1, 74f\n"
+    "ld1 { v25.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 75f\n"
+    "ld1 { v25.b }[2], [x20]\n"
+    "b 75f\n"
+    "74:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 75f\n"
+    "ld1 { v25.b }[0], [x20]\n"
+    "75:"  // Oddments: Load (4, 2): Bit 2: End
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ldr x19, [x13, #0xc0]\n"
+    "smlal v11.4s, v25.4h, v8.4h\n"
+    "smlal2 v10.4s, v25.8h, v8.8h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "smlal2 v9.4s, v25.8h, v6.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 77f\n"
+    "ld1 { v29.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 76f\n"
+    "ld1 { v29.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 79f\n"
+    "ld1 { v29.b }[6], [x19]\n"
+    "b 79f\n"
+    "76:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
+    "tbz x8, #0, 79f\n"
+    "ld1 { v29.b }[4], [x19]\n"
+    "b 79f\n"
+    "77:"  // Oddments: Load (4, 4): Bit 2: Unset
+    "tbz x8, #1, 78f\n"
+    "ld1 { v29.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 79f\n"
+    "ld1 { v29.b }[2], [x19]\n"
+    "b 79f\n"
+    "78:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 79f\n"
+    "ld1 { v29.b }[0], [x19]\n"
+    "79:"  // Oddments: Load (4, 4): Bit 2: End
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "smlal2 v9.4s, v29.8h, v8.8h\n"
+    "tbz x8, #2, 81f\n"
+    "ld1 { v22.4s }, [x12], #0x10\n"
+    "ld1 { v31.4s }, [x11], #0x10\n"
+    "tbz x8, #1, 80f\n"
+    "ld1 { v19.d }[0], [x12], #0x8\n"
+    "ld1 { v30.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[2], [x12]\n"
+    "ld1 { v30.s }[2], [x11]\n"
+    "b 83f\n"
+    "80:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[0], [x12]\n"
+    "ld1 { v30.s }[0], [x11]\n"
+    "b 83f\n"
+    "81:"  // Oddments: Load requant params: Bit 2: Unset
+    "tbz x8, #1, 82f\n"
+    "ld1 { v22.d }[0], [x12], #0x8\n"
+    "ld1 { v31.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v22.s }[2], [x12]\n"
+    "ld1 { v31.s }[2], [x11]\n"
+    "b 83f\n"
+    "82:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 83f\n"
+    "ld1 { v22.s }[0], [x12]\n"
+    "ld1 { v31.s }[0], [x11]\n"
+    "83:"  // Oddments: Load requant params: Bit 2: End
+    "sqdmulh v15.4s, v15.4s, v22.4s\n"
+    "sqdmulh v13.4s, v13.4s, v22.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v11.4s, v11.4s, v22.4s\n"
+    "sqdmulh v23.4s, v23.4s, v22.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v4.16b, v15.16b, v31.16b\n"
+    "sqdmulh v18.4s, v18.4s, v19.4s\n"
+    "and v1.16b, v13.16b, v31.16b\n"
+    "sqdmulh v17.4s, v17.4s, v19.4s\n"
+    "and v22.16b, v11.16b, v31.16b\n"
+    "sqdmulh v10.4s, v10.4s, v19.4s\n"
+    "and v20.16b, v23.16b, v31.16b\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v19.16b, v18.16b, v30.16b\n"
+    "sshr v1.4s, v1.4s, #0x1f\n"
+    "and v27.16b, v17.16b, v30.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v25.16b, v10.16b, v30.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v0.16b, v9.16b, v30.16b\n"
+    "sqadd v15.4s, v15.4s, v4.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v13.4s, v13.4s, v1.4s\n"
+    "sshr v27.4s, v27.4s, #0x1f\n"
+    "sqadd v11.4s, v11.4s, v22.4s\n"
+    "sshr v25.4s, v25.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
+    "sshr v0.4s, v0.4s, #0x1f\n"
+    "srshl v15.4s, v15.4s, v31.4s\n"
+    "sqadd v18.4s, v18.4s, v19.4s\n"
+    "srshl v13.4s, v13.4s, v31.4s\n"
+    "sqadd v17.4s, v17.4s, v27.4s\n"
+    "srshl v11.4s, v11.4s, v31.4s\n"
+    "sqadd v10.4s, v10.4s, v25.4s\n"
+    "srshl v23.4s, v23.4s, v31.4s\n"
+    "sqadd v9.4s, v9.4s, v0.4s\n"
+    "srshl v18.4s, v18.4s, v30.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v17.4s, v17.4s, v30.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v10.4s, v10.4s, v30.4s\n"
+    "sqxtn v11.4h, v11.4s\n"
+    "srshl v9.4s, v9.4s, v30.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v18.4s\n"
+    "sqxtn2 v13.8h, v17.4s\n"
+    "sqxtn2 v11.8h, v10.4s\n"
+    "sqxtn2 v23.8h, v9.4s\n"
+    "sqadd v15.8h, v15.8h, v12.8h\n"
+    "sqadd v13.8h, v13.8h, v12.8h\n"
+    "sqadd v11.8h, v11.8h, v12.8h\n"
+    "sqadd v23.8h, v23.8h, v12.8h\n"
+    "smax v15.8h, v15.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v14.8h\n"
+    "smax v11.8h, v11.8h, v14.8h\n"
+    "smax v23.8h, v23.8h, v14.8h\n"
+    "smin v15.8h, v15.8h, v21.8h\n"
+    "smin v13.8h, v13.8h, v21.8h\n"
+    "smin v11.8h, v11.8h, v21.8h\n"
+    "smin v23.8h, v23.8h, v21.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v11.16b, v11.16b, v11.16b\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "tbz x8, #2, 85f\n"
+    "st1 { v15.s }[0], [x10], #0x4\n"
+    "st1 { v13.s }[0], [x9], #0x4\n"
+    "st1 { v11.s }[0], [x28], #0x4\n"
+    "st1 { v23.s }[0], [x27], #0x4\n"
+    "tbz x8, #1, 84f\n"
+    "st1 { v15.h }[2], [x10], #0x2\n"
+    "st1 { v13.h }[2], [x9], #0x2\n"
+    "st1 { v11.h }[2], [x28], #0x2\n"
+    "st1 { v23.h }[2], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[6], [x10], #0x1\n"
+    "st1 { v13.b }[6], [x9], #0x1\n"
+    "st1 { v11.b }[6], [x28], #0x1\n"
+    "st1 { v23.b }[6], [x27], #0x1\n"
+    "b 87f\n"
+    "84:"  // Oddments: Bit 2: Bit 1: Unset
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[4], [x10], #0x1\n"
+    "st1 { v13.b }[4], [x9], #0x1\n"
+    "st1 { v11.b }[4], [x28], #0x1\n"
+    "st1 { v23.b }[4], [x27], #0x1\n"
+    "b 87f\n"
+    "85:"  // Oddments: Bit 2: Unset
+    "tbz x8, #1, 86f\n"
+    "st1 { v15.h }[0], [x10], #0x2\n"
+    "st1 { v13.h }[0], [x9], #0x2\n"
+    "st1 { v11.h }[0], [x28], #0x2\n"
+    "st1 { v23.h }[0], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[2], [x10], #0x1\n"
+    "st1 { v13.b }[2], [x9], #0x1\n"
+    "st1 { v11.b }[2], [x28], #0x1\n"
+    "st1 { v23.b }[2], [x27], #0x1\n"
+    "b 87f\n"
+    "86:"  // Oddments: Bit 2: Unset: Bit 1: Unset
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[0], [x10], #0x1\n"
+    "st1 { v13.b }[0], [x9], #0x1\n"
+    "st1 { v11.b }[0], [x28], #0x1\n"
+    "st1 { v23.b }[0], [x27], #0x1\n"
+    "87:"  // Oddments: Bit 2: End
+    "88:"  // End
+    :
+    : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
+    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+  );
+}
+
+}  // namespace depthwise
+}  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
new file mode 100644
index 0000000..482d1af
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+
+#include "src/core/NEON/kernels/arm_conv/depthwise/interleaves/list.hpp"
+
+#include <cstdint>
+
+#pragma once
+
+#if defined(__aarch64__)
+
+namespace arm_conv {
+namespace depthwise {
+
+void a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);  // Arguments in order (names as used by the definition in generic.cpp): n_channels, inptrs, weights, bias, qp, requant_muls, requant_shifts, outptrs.
+
+class a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy descriptor exposing the a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst_impl kernel together with its geometry constants.
+{
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
+
+  public:
+  constexpr static unsigned int kernel_rows = 5;  // Kernel window height.
+  constexpr static unsigned int kernel_cols = 5;  // Kernel window width.
+
+  constexpr static unsigned int stride_rows = 1;  // Vertical stride.
+  constexpr static unsigned int stride_cols = 1;  // Horizontal stride.
+
+  a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 5, 5, 1, 1) {}  // The CPUInfo argument is unused. NOTE(review): Parent arguments are presumably (output rows, output cols, kernel rows, kernel cols, stride rows, stride cols) - confirm against DepthwiseDepthfirstStrategy.
+
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }  // Always VLType::None. NOTE(review): presumably marks a kernel that does not use a scalable vector length - confirm against arm_gemm::VLType.
+
+  Parent::KernelType kernel = a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;  // Kernel entry point; initialised to the free function declared before this class.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of `kernel`.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // Always 2. NOTE(review): presumably the accumulator depth in units of vector length - confirm against Parent.
+};
+
+}  // namespace depthwise
+}  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
new file mode 100644
index 0000000..7f1fd1d
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -0,0 +1,2185 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_gemm.hpp"
+
+#include <cstddef>
+#include <cstdint>
+
+#if defined(__aarch64__)
+
+namespace arm_conv {
+namespace depthwise {
+
+void a64_u8qa_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(
+  const unsigned int n_channels,
+  const uint8_t *const *const inptrs,
+  const uint8_t *const weights,
+  const int32_t *const bias,
+  const arm_gemm::Requantize32 &qp,
+  const int32_t *const requant_muls,
+  const int32_t *const requant_shifts,
+  uint8_t *const *const outptrs
+)
+{
+  struct Params  // Argument bundle handed to the inline assembly through one pointer; the assembly addresses its fields with offsetof().
+  {
+    long unsigned int n_channels;  // Channel count; the assembly loads it into a 64-bit register and shifts it right by 3 for its loop counter.
+    const void *weights;  // Weight buffer read by the assembly.
+    const int32_t *bias;  // The pointer itself is not const: the assembly stores the advanced bias pointer back into this field.
+    const arm_gemm::Requantize32 *requant;  // Quantisation parameters; the assembly reads b_offset, c_offset, minval and maxval from it.
+    const int32_t *const requant_muls;  // Requantisation multipliers; advanced by 0x20 bytes per main-loop iteration (presumably one int32 per channel - confirm).
+    const int32_t *const requant_shifts;  // Requantisation shifts; advanced in step with requant_muls.
+    uint8_t *const *const outptrs;  // Output pointer array; the assembly loads four entries from it.
+    const uint8_t *inptrs[36];  // Input pointers in the order set up by the constructor. NOTE(review): 36 presumably corresponds to a 6x6 input patch - confirm.
+
+    Params(  // Copies the scalar and pointer arguments, then fills inptrs[] from inptrs_raw with the first 13 entries permuted.
+      long unsigned int n_channels,
+      const uint8_t *const *inptrs_raw,
+      const void *const weights,
+      const int32_t *const bias,
+      const arm_gemm::Requantize32 &qp,
+      const int32_t *const requant_muls,
+      const int32_t *const requant_shifts,
+      uint8_t *const *outptrs
+    ) : n_channels(n_channels), weights(weights), bias(bias),
+        requant(&qp), requant_muls(requant_muls),  // The address of qp is stored, so qp must outlive this object.
+        requant_shifts(requant_shifts), outptrs(outptrs)
+    {
+      inptrs[0] = inptrs_raw[0];  // Entries 0..12 are a permutation of inptrs_raw[0..12]. NOTE(review): presumably the order in which the assembly consumes them - confirm against the kernel generator.
+      inptrs[1] = inptrs_raw[1];
+      inptrs[2] = inptrs_raw[6];
+      inptrs[3] = inptrs_raw[7];
+      inptrs[4] = inptrs_raw[2];
+      inptrs[5] = inptrs_raw[8];
+      inptrs[6] = inptrs_raw[3];
+      inptrs[7] = inptrs_raw[4];
+      inptrs[8] = inptrs_raw[11];
+      inptrs[9] = inptrs_raw[12];
+      inptrs[10] = inptrs_raw[9];
+      inptrs[11] = inptrs_raw[10];
+      inptrs[12] = inptrs_raw[5];
+      inptrs[13] = inptrs_raw[13];  // From here to index 35 the pointers are copied unchanged.
+      inptrs[14] = inptrs_raw[14];
+      inptrs[15] = inptrs_raw[15];
+      inptrs[16] = inptrs_raw[16];
+      inptrs[17] = inptrs_raw[17];
+      inptrs[18] = inptrs_raw[18];
+      inptrs[19] = inptrs_raw[19];
+      inptrs[20] = inptrs_raw[20];
+      inptrs[21] = inptrs_raw[21];
+      inptrs[22] = inptrs_raw[22];
+      inptrs[23] = inptrs_raw[23];
+      inptrs[24] = inptrs_raw[24];
+      inptrs[25] = inptrs_raw[25];
+      inptrs[26] = inptrs_raw[26];
+      inptrs[27] = inptrs_raw[27];
+      inptrs[28] = inptrs_raw[28];
+      inptrs[29] = inptrs_raw[29];
+      inptrs[30] = inptrs_raw[30];
+      inptrs[31] = inptrs_raw[31];
+      inptrs[32] = inptrs_raw[32];
+      inptrs[33] = inptrs_raw[33];
+      inptrs[34] = inptrs_raw[34];
+      inptrs[35] = inptrs_raw[35];
+
+    }
+  };
+
+  const Params params(n_channels, inptrs, weights, bias, qp,
+                      requant_muls, requant_shifts, outptrs);
+
+  __asm__ __volatile__(
+    "ldr x16, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x4, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x9, x16, %[offsetof_Requantize32_b_offset]\n"
+    "add x19, x16, %[offsetof_Requantize32_c_offset]\n"
+    "ldr x10, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x24, x16, %[offsetof_Requantize32_minval]\n"
+    "add x2, x16, %[offsetof_Requantize32_maxval]\n"
+    "ldr x8, [%x[params], %[offsetof_Params_weights]]\n"
+    "ld1r { v15.16b }, [x9]\n"
+    "ld1r { v16.8h }, [x19]\n"
+    "lsr x3, x4, #0x3\n"
+    "mov x1, #0x0\n"
+    "ld1r { v12.8h }, [x24]\n"
+    "ld1r { v13.8h }, [x2]\n"
+    "mov x2, #0x0\n"
+    "add x0, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x5, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "ldr x6, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x21, x15, [x10, #0x0]\n"
+    "ldp x17, x16, [x10, #0x10]\n"
+    "cbz x3, 3f\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q11, [x19, #0x0]\n"
+    "subs x3, x3, #0x1\n"
+    "mov v14.16b, v11.16b\n"
+    "ldr q21, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x8, #0x0]\n"
+    "ldr d1, [x8, #0x8]\n"
+    "ldr d2, [x8, #0x10]\n"
+    "mov v10.16b, v21.16b\n"
+    "mov v9.16b, v11.16b\n"
+    "ldr d3, [x8, #0x18]\n"
+    "ldr d4, [x8, #0x20]\n"
+    "mov v8.16b, v21.16b\n"
+    "mov v7.16b, v11.16b\n"
+    "ldp x28, x27, [x0, #0x0]\n"
+    "ldp x10, x26, [x0, #0x10]\n"
+    "mov v6.16b, v21.16b\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "ldp x24, x23, [x0, #0x20]\n"
+    "ldp x22, x25, [x0, #0x30]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "ldp x20, x19, [x0, #0x40]\n"
+    "ldr d31, [x28, x1]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "ldr d30, [x27, x1]\n"
+    "ldr d29, [x10, x1]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ldr d28, [x26, x1]\n"
+    "ldr d27, [x24, x1]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr d23, [x23, x1]\n"
+    "ldr d25, [x22, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "ldr d24, [x25, x1]\n"
+    "ldr d26, [x20, x1]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "ldr d22, [x19, x1]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "ushll v22.8h, v22.8b, #0x0\n"
+    "beq 2f\n"
+    "1:"  // Loop
+    "smlal v11.4s, v31.4h, v0.4h\n"
+    "smlal2 v21.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x0, #0x50]\n"
+    "ldr d31, [x19, x1]\n"
+    "smlal v14.4s, v30.4h, v0.4h\n"
+    "smlal v9.4s, v29.4h, v0.4h\n"
+    "ldr x20, [x0, #0x58]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v7.4s, v28.4h, v0.4h\n"
+    "smlal2 v10.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x0, #0x60]\n"
+    "ldr x24, [x0, #0x68]\n"
+    "smlal2 v8.4s, v29.8h, v0.8h\n"
+    "smlal v11.4s, v30.4h, v1.4h\n"
+    "ldr x23, [x0, #0x70]\n"
+    "ldr x26, [x0, #0x78]\n"
+    "smlal2 v21.4s, v30.8h, v1.8h\n"
+    "smlal2 v6.4s, v28.8h, v0.8h\n"
+    "ldr d30, [x20, x1]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v14.4s, v27.4h, v1.4h\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x8, #0x28]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "smlal v7.4s, v23.4h, v1.4h\n"
+    "smlal2 v10.4s, v27.8h, v1.8h\n"
+    "ldr x7, [x0, #0x80]\n"
+    "ldr x22, [x0, #0x88]\n"
+    "smlal2 v8.4s, v28.8h, v1.8h\n"
+    "smlal v11.4s, v27.4h, v2.4h\n"
+    "ldr x20, [x0, #0x90]\n"
+    "ldr x14, [x0, #0x98]\n"
+    "smlal2 v21.4s, v27.8h, v2.8h\n"
+    "smlal2 v6.4s, v23.8h, v1.8h\n"
+    "ldr d27, [x19, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v14.4s, v25.4h, v2.4h\n"
+    "smlal v9.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x8, #0x30]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v31.4h, v2.4h\n"
+    "smlal2 v10.4s, v25.8h, v2.8h\n"
+    "ldr x19, [x0, #0xa0]\n"
+    "ldr x13, [x0, #0xa8]\n"
+    "smlal2 v8.4s, v23.8h, v2.8h\n"
+    "smlal v11.4s, v25.4h, v3.4h\n"
+    "ldr x12, [x0, #0xb0]\n"
+    "ldr x11, [x0, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v3.8h\n"
+    "smlal2 v6.4s, v31.8h, v2.8h\n"
+    "ldr d25, [x24, x1]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v14.4s, v24.4h, v3.4h\n"
+    "smlal v9.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x8, #0x38]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v30.4h, v3.4h\n"
+    "smlal2 v10.4s, v24.8h, v3.8h\n"
+    "ldr x10, [x0, #0xc0]\n"
+    "ldr x9, [x0, #0xc8]\n"
+    "smlal2 v8.4s, v31.8h, v3.8h\n"
+    "smlal v11.4s, v24.4h, v4.4h\n"
+    "ldr x28, [x0, #0xd0]\n"
+    "ldr x27, [x0, #0xd8]\n"
+    "smlal2 v21.4s, v24.8h, v4.8h\n"
+    "smlal2 v6.4s, v30.8h, v3.8h\n"
+    "ldr d24, [x23, x1]\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v14.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x8, #0x40]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v26.4h, v4.4h\n"
+    "smlal2 v10.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x26, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal2 v8.4s, v30.8h, v4.8h\n"
+    "smlal v11.4s, v29.4h, v0.4h\n"
+    "ldr x26, [x0, #0xe0]\n"
+    "ldr x25, [x0, #0xe8]\n"
+    "smlal2 v21.4s, v29.8h, v0.8h\n"
+    "smlal2 v6.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x8, #0x48]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v14.4s, v28.4h, v0.4h\n"
+    "smlal v9.4s, v22.4h, v0.4h\n"
+    "ldr x24, [x0, #0xf0]\n"
+    "ldr q17, [x5, #0x0]\n"
+    "smlal v7.4s, v25.4h, v0.4h\n"
+    "smlal2 v10.4s, v28.8h, v0.8h\n"
+    "ldr q5, [x6, #0x0]\n"
+    "ldr q18, [x5, #0x10]\n"
+    "smlal2 v8.4s, v22.8h, v0.8h\n"
+    "smlal v11.4s, v28.4h, v1.4h\n"
+    "ldr q29, [x6, #0x10]\n"
+    "subs x3, x3, #0x1\n"
+    "smlal2 v21.4s, v28.8h, v1.8h\n"
+    "smlal2 v6.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x22, x1]\n"
+    "ldr d0, [x8, #0x50]\n"
+    "smlal v14.4s, v23.4h, v1.4h\n"
+    "smlal v9.4s, v25.4h, v1.4h\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr x23, [x0, #0xf8]\n"
+    "smlal v7.4s, v24.4h, v1.4h\n"
+    "smlal2 v10.4s, v23.8h, v1.8h\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "add x5, x5, #0x20\n"
+    "smlal2 v8.4s, v25.8h, v1.8h\n"
+    "smlal v11.4s, v23.4h, v2.4h\n"
+    "add x6, x6, #0x20\n"
+    "smlal2 v21.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x7, x1]\n"
+    "smlal2 v6.4s, v24.8h, v1.8h\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "smlal v14.4s, v31.4h, v2.4h\n"
+    "smlal v9.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x8, #0x58]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v27.4h, v2.4h\n"
+    "smlal2 v10.4s, v31.8h, v2.8h\n"
+    "ldr x22, [x0, #0x100]\n"
+    "smlal2 v8.4s, v24.8h, v2.8h\n"
+    "smlal v11.4s, v31.4h, v3.4h\n"
+    "smlal2 v21.4s, v31.8h, v3.8h\n"
+    "smlal2 v6.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x20, x1]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v14.4s, v30.4h, v3.4h\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x8, #0x60]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v23.4h, v3.4h\n"
+    "smlal2 v10.4s, v30.8h, v3.8h\n"
+    "ldr x7, [x0, #0x108]\n"
+    "smlal2 v8.4s, v27.8h, v3.8h\n"
+    "smlal v11.4s, v30.4h, v4.4h\n"
+    "smlal2 v21.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x14, x1]\n"
+    "smlal2 v6.4s, v23.8h, v3.8h\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v14.4s, v26.4h, v4.4h\n"
+    "smlal v9.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x8, #0x68]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v28.4h, v4.4h\n"
+    "smlal2 v10.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x19, x1]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal2 v8.4s, v23.8h, v4.8h\n"
+    "smlal v11.4s, v22.4h, v0.4h\n"
+    "ldr x20, [x0, #0x110]\n"
+    "ldr x19, [x0, #0x118]\n"
+    "smlal2 v21.4s, v22.8h, v0.8h\n"
+    "smlal2 v6.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x8, #0x70]\n"
+    "ldr d22, [x11, x1]\n"
+    "smlal v14.4s, v25.4h, v0.4h\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v7.4s, v30.4h, v0.4h\n"
+    "smlal2 v10.4s, v25.8h, v0.8h\n"
+    "ushll v22.8h, v22.8b, #0x0\n"
+    "smlal2 v8.4s, v31.8h, v0.8h\n"
+    "smlal v11.4s, v25.4h, v1.4h\n"
+    "smlal2 v21.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x13, x1]\n"
+    "smlal2 v6.4s, v30.8h, v0.8h\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v14.4s, v24.4h, v1.4h\n"
+    "smlal v9.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x8, #0x78]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "smlal v7.4s, v26.4h, v1.4h\n"
+    "smlal2 v10.4s, v24.8h, v1.8h\n"
+    "smlal2 v8.4s, v30.8h, v1.8h\n"
+    "smlal v11.4s, v24.4h, v2.4h\n"
+    "smlal2 v21.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x12, x1]\n"
+    "smlal2 v6.4s, v26.8h, v1.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v14.4s, v27.4h, v2.4h\n"
+    "smlal v9.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x8, #0x80]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v25.4h, v2.4h\n"
+    "smlal2 v10.4s, v27.8h, v2.8h\n"
+    "smlal2 v8.4s, v26.8h, v2.8h\n"
+    "smlal v11.4s, v27.4h, v3.4h\n"
+    "smlal2 v21.4s, v27.8h, v3.8h\n"
+    "smlal2 v6.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x10, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v14.4s, v23.4h, v3.4h\n"
+    "smlal v9.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x8, #0x88]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v24.4h, v3.4h\n"
+    "smlal2 v10.4s, v23.8h, v3.8h\n"
+    "smlal2 v8.4s, v25.8h, v3.8h\n"
+    "smlal v11.4s, v23.4h, v4.4h\n"
+    "smlal2 v21.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x9, x1]\n"
+    "smlal2 v6.4s, v24.8h, v3.8h\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "smlal v14.4s, v28.4h, v4.4h\n"
+    "smlal v9.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x8, #0x90]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v22.4h, v4.4h\n"
+    "smlal2 v10.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x26, x1]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal2 v8.4s, v24.8h, v4.8h\n"
+    "smlal v11.4s, v31.4h, v0.4h\n"
+    "smlal2 v21.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x28, x1]\n"
+    "smlal2 v6.4s, v22.8h, v4.8h\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v14.4s, v30.4h, v0.4h\n"
+    "smlal v9.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x8, #0x98]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v7.4s, v23.4h, v0.4h\n"
+    "smlal2 v10.4s, v30.8h, v0.8h\n"
+    "smlal2 v8.4s, v27.8h, v0.8h\n"
+    "smlal v11.4s, v30.4h, v1.4h\n"
+    "smlal2 v21.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x1]\n"
+    "smlal2 v6.4s, v23.8h, v0.8h\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v14.4s, v26.4h, v1.4h\n"
+    "smlal v9.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x8, #0xa0]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "smlal v7.4s, v31.4h, v1.4h\n"
+    "smlal2 v10.4s, v26.8h, v1.8h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "smlal v11.4s, v26.4h, v2.4h\n"
+    "smlal2 v21.4s, v26.8h, v2.8h\n"
+    "smlal2 v6.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x25, x1]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v14.4s, v25.4h, v2.4h\n"
+    "smlal v9.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x8, #0xa8]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v30.4h, v2.4h\n"
+    "smlal2 v10.4s, v25.8h, v2.8h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "smlal v11.4s, v25.4h, v3.4h\n"
+    "smlal2 v21.4s, v25.8h, v3.8h\n"
+    "smlal2 v6.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x24, x1]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v14.4s, v24.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x8, #0xb0]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v28.4h, v3.4h\n"
+    "smlal2 v10.4s, v24.8h, v3.8h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "smlal v11.4s, v24.4h, v4.4h\n"
+    "smlal2 v21.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x23, x1]\n"
+    "smlal2 v6.4s, v28.8h, v3.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v14.4s, v22.4h, v4.4h\n"
+    "smlal v9.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x8, #0xb8]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "smlal v11.4s, v27.4h, v0.4h\n"
+    "smlal2 v21.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x22, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal2 v10.4s, v22.8h, v4.8h\n"
+    "smlal2 v6.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x8, #0xc0]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v14.4s, v23.4h, v0.4h\n"
+    "smlal v9.4s, v25.4h, v0.4h\n"
+    "add x8, x8, #0xc8\n"
+    "smlal v7.4s, v24.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x7, x1]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal2 v10.4s, v23.8h, v0.8h\n"
+    "smlal2 v6.4s, v24.8h, v0.8h\n"
+    "smlal v11.4s, v23.4h, v1.4h\n"
+    "smlal v14.4s, v31.4h, v1.4h\n"
+    "smlal v9.4s, v24.4h, v1.4h\n"
+    "smlal v7.4s, v27.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x20, x1]\n"
+    "smlal2 v21.4s, v23.8h, v1.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal2 v10.4s, v31.8h, v1.8h\n"
+    "smlal2 v6.4s, v27.8h, v1.8h\n"
+    "smlal v11.4s, v31.4h, v2.4h\n"
+    "smlal v14.4s, v30.4h, v2.4h\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal v7.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x19, x1]\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal2 v10.4s, v30.8h, v2.8h\n"
+    "smlal2 v6.4s, v25.8h, v2.8h\n"
+    "add x1, x1, #0x8\n"
+    "smlal v11.4s, v30.4h, v3.4h\n"
+    "smlal v14.4s, v28.4h, v3.4h\n"
+    "smlal v9.4s, v25.4h, v3.4h\n"
+    "smlal v7.4s, v24.4h, v3.4h\n"
+    "smlal2 v21.4s, v30.8h, v3.8h\n"
+    "smlal2 v10.4s, v28.8h, v3.8h\n"
+    "smlal2 v8.4s, v25.8h, v3.8h\n"
+    "smlal2 v6.4s, v24.8h, v3.8h\n"
+    "smlal v11.4s, v28.4h, v4.4h\n"
+    "smlal v14.4s, v26.4h, v4.4h\n"
+    "sqdmulh v11.4s, v11.4s, v17.4s\n"
+    "smlal v9.4s, v24.4h, v4.4h\n"
+    "smlal v7.4s, v27.4h, v4.4h\n"
+    "sqdmulh v14.4s, v14.4s, v17.4s\n"
+    "smlal2 v21.4s, v28.8h, v4.8h\n"
+    "smlal2 v10.4s, v26.8h, v4.8h\n"
+    "sqdmulh v9.4s, v9.4s, v17.4s\n"
+    "smlal2 v8.4s, v24.8h, v4.8h\n"
+    "smlal2 v6.4s, v27.8h, v4.8h\n"
+    "sqdmulh v7.4s, v7.4s, v17.4s\n"
+    "and v23.16b, v11.16b, v5.16b\n"
+    "sqdmulh v21.4s, v21.4s, v18.4s\n"
+    "and v22.16b, v14.16b, v5.16b\n"
+    "sqdmulh v10.4s, v10.4s, v18.4s\n"
+    "and v17.16b, v9.16b, v5.16b\n"
+    "sqdmulh v8.4s, v8.4s, v18.4s\n"
+    "and v20.16b, v7.16b, v5.16b\n"
+    "sqdmulh v6.4s, v6.4s, v18.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v19.16b, v21.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v18.16b, v10.16b, v29.16b\n"
+    "sshr v17.4s, v17.4s, #0x1f\n"
+    "and v26.16b, v8.16b, v29.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v4.16b, v6.16b, v29.16b\n"
+    "sqadd v11.4s, v11.4s, v23.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v14.4s, v14.4s, v22.4s\n"
+    "sshr v18.4s, v18.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v17.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "sqadd v7.4s, v7.4s, v20.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "srshl v11.4s, v11.4s, v5.4s\n"
+    "sqadd v21.4s, v21.4s, v19.4s\n"
+    "srshl v14.4s, v14.4s, v5.4s\n"
+    "sqadd v10.4s, v10.4s, v18.4s\n"
+    "srshl v9.4s, v9.4s, v5.4s\n"
+    "sqadd v8.4s, v8.4s, v26.4s\n"
+    "srshl v7.4s, v7.4s, v5.4s\n"
+    "sqadd v6.4s, v6.4s, v4.4s\n"
+    "srshl v21.4s, v21.4s, v29.4s\n"
+    "sqxtn v11.4h, v11.4s\n"
+    "srshl v10.4s, v10.4s, v29.4s\n"
+    "sqxtn v14.4h, v14.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v6.4s, v6.4s, v29.4s\n"
+    "sqxtn v7.4h, v7.4s\n"
+    "sqxtn2 v11.8h, v21.4s\n"
+    "sqxtn2 v14.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v8.4s\n"
+    "sqxtn2 v7.8h, v6.4s\n"
+    "sqadd v11.8h, v11.8h, v16.8h\n"
+    "sqadd v14.8h, v14.8h, v16.8h\n"
+    "sqadd v9.8h, v9.8h, v16.8h\n"
+    "sqadd v7.8h, v7.8h, v16.8h\n"
+    "smax v11.8h, v11.8h, v12.8h\n"
+    "smax v14.8h, v14.8h, v12.8h\n"
+    "smax v9.8h, v9.8h, v12.8h\n"
+    "smax v7.8h, v7.8h, v12.8h\n"
+    "smin v11.8h, v11.8h, v13.8h\n"
+    "smin v14.8h, v14.8h, v13.8h\n"
+    "smin v9.8h, v9.8h, v13.8h\n"
+    "smin v7.8h, v7.8h, v13.8h\n"
+    "uzp1 v11.16b, v11.16b, v11.16b\n"
+    "uzp1 v14.16b, v14.16b, v14.16b\n"
+    "str d11, [x21, x2]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v7.16b, v7.16b, v7.16b\n"
+    "str d14, [x15, x2]\n"
+    "str d9, [x17, x2]\n"
+    "str d7, [x16, x2]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q11, [x19, #0x0]\n"
+    "add x2, x2, #0x8\n"
+    "ldr q21, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x8, #0x0]\n"
+    "ldr d1, [x8, #0x8]\n"
+    "ldr d2, [x8, #0x10]\n"
+    "mov v14.16b, v11.16b\n"
+    "mov v10.16b, v21.16b\n"
+    "ldr d3, [x8, #0x18]\n"
+    "ldr d4, [x8, #0x20]\n"
+    "mov v9.16b, v11.16b\n"
+    "mov v8.16b, v21.16b\n"
+    "ldp x28, x27, [x0, #0x0]\n"
+    "ldp x10, x26, [x0, #0x10]\n"
+    "mov v7.16b, v11.16b\n"
+    "mov v6.16b, v21.16b\n"
+    "ldp x24, x23, [x0, #0x20]\n"
+    "ldp x22, x25, [x0, #0x30]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldp x20, x19, [x0, #0x40]\n"
+    "ldr d31, [x28, x1]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "ldr d30, [x27, x1]\n"
+    "ldr d29, [x10, x1]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ldr d28, [x26, x1]\n"
+    "ldr d27, [x24, x1]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "ldr d23, [x23, x1]\n"
+    "ldr d25, [x22, x1]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ldr d24, [x25, x1]\n"
+    "ldr d26, [x20, x1]\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ldr d22, [x19, x1]\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "ushll v22.8h, v22.8b, #0x0\n"
+    "bgt 1b\n"
+    "2:"  // Tail
+    "smlal v11.4s, v31.4h, v0.4h\n"
+    "smlal2 v21.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x0, #0x50]\n"
+    "ldr d31, [x19, x1]\n"
+    "smlal v14.4s, v30.4h, v0.4h\n"
+    "smlal v9.4s, v29.4h, v0.4h\n"
+    "ldr x20, [x0, #0x58]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v7.4s, v28.4h, v0.4h\n"
+    "smlal2 v10.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x0, #0x60]\n"
+    "ldr x24, [x0, #0x68]\n"
+    "smlal2 v8.4s, v29.8h, v0.8h\n"
+    "smlal v11.4s, v30.4h, v1.4h\n"
+    "ldr x23, [x0, #0x70]\n"
+    "ldr x26, [x0, #0x78]\n"
+    "smlal2 v21.4s, v30.8h, v1.8h\n"
+    "smlal2 v6.4s, v28.8h, v0.8h\n"
+    "ldr d30, [x20, x1]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v14.4s, v27.4h, v1.4h\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x8, #0x28]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "smlal v7.4s, v23.4h, v1.4h\n"
+    "smlal2 v10.4s, v27.8h, v1.8h\n"
+    "ldr x7, [x0, #0x80]\n"
+    "ldr x22, [x0, #0x88]\n"
+    "smlal2 v8.4s, v28.8h, v1.8h\n"
+    "smlal v11.4s, v27.4h, v2.4h\n"
+    "ldr x20, [x0, #0x90]\n"
+    "ldr x14, [x0, #0x98]\n"
+    "smlal2 v21.4s, v27.8h, v2.8h\n"
+    "smlal2 v6.4s, v23.8h, v1.8h\n"
+    "ldr d27, [x19, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v14.4s, v25.4h, v2.4h\n"
+    "smlal v9.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x8, #0x30]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v31.4h, v2.4h\n"
+    "smlal2 v10.4s, v25.8h, v2.8h\n"
+    "ldr x19, [x0, #0xa0]\n"
+    "ldr x13, [x0, #0xa8]\n"
+    "smlal2 v8.4s, v23.8h, v2.8h\n"
+    "smlal v11.4s, v25.4h, v3.4h\n"
+    "ldr x12, [x0, #0xb0]\n"
+    "ldr x11, [x0, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v3.8h\n"
+    "smlal2 v6.4s, v31.8h, v2.8h\n"
+    "ldr d25, [x24, x1]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v14.4s, v24.4h, v3.4h\n"
+    "smlal v9.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x8, #0x38]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v30.4h, v3.4h\n"
+    "smlal2 v10.4s, v24.8h, v3.8h\n"
+    "ldr x10, [x0, #0xc0]\n"
+    "ldr x9, [x0, #0xc8]\n"
+    "smlal2 v8.4s, v31.8h, v3.8h\n"
+    "smlal v11.4s, v24.4h, v4.4h\n"
+    "ldr x28, [x0, #0xd0]\n"
+    "ldr x27, [x0, #0xd8]\n"
+    "smlal2 v21.4s, v24.8h, v4.8h\n"
+    "smlal2 v6.4s, v30.8h, v3.8h\n"
+    "ldr d24, [x23, x1]\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v14.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x8, #0x40]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v26.4h, v4.4h\n"
+    "smlal2 v10.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x26, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal2 v8.4s, v30.8h, v4.8h\n"
+    "smlal v11.4s, v29.4h, v0.4h\n"
+    "ldr x26, [x0, #0xe0]\n"
+    "ldr x25, [x0, #0xe8]\n"
+    "smlal2 v21.4s, v29.8h, v0.8h\n"
+    "smlal2 v6.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x8, #0x48]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v14.4s, v28.4h, v0.4h\n"
+    "smlal v9.4s, v22.4h, v0.4h\n"
+    "ldr x24, [x0, #0xf0]\n"
+    "ldr x23, [x0, #0xf8]\n"
+    "smlal v7.4s, v25.4h, v0.4h\n"
+    "smlal2 v10.4s, v28.8h, v0.8h\n"
+    "ldr q17, [x5, #0x0]\n"
+    "ldr q5, [x6, #0x0]\n"
+    "smlal2 v8.4s, v22.8h, v0.8h\n"
+    "smlal v11.4s, v28.4h, v1.4h\n"
+    "ldr q18, [x5, #0x10]\n"
+    "ldr q29, [x6, #0x10]\n"
+    "smlal2 v21.4s, v28.8h, v1.8h\n"
+    "smlal2 v6.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x22, x1]\n"
+    "ldr d0, [x8, #0x50]\n"
+    "smlal v14.4s, v23.4h, v1.4h\n"
+    "smlal v9.4s, v25.4h, v1.4h\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "ldr x22, [x0, #0x100]\n"
+    "smlal v7.4s, v24.4h, v1.4h\n"
+    "smlal2 v10.4s, v23.8h, v1.8h\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "tst x4, #0x7\n"
+    "smlal2 v8.4s, v25.8h, v1.8h\n"
+    "smlal v11.4s, v23.4h, v2.4h\n"
+    "add x5, x5, #0x20\n"
+    "add x6, x6, #0x20\n"
+    "smlal2 v21.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x7, x1]\n"
+    "smlal2 v6.4s, v24.8h, v1.8h\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "smlal v14.4s, v31.4h, v2.4h\n"
+    "smlal v9.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x8, #0x58]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v27.4h, v2.4h\n"
+    "smlal2 v10.4s, v31.8h, v2.8h\n"
+    "ldr x7, [x0, #0x108]\n"
+    "smlal2 v8.4s, v24.8h, v2.8h\n"
+    "smlal v11.4s, v31.4h, v3.4h\n"
+    "smlal2 v21.4s, v31.8h, v3.8h\n"
+    "smlal2 v6.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x20, x1]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v14.4s, v30.4h, v3.4h\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x8, #0x60]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v23.4h, v3.4h\n"
+    "smlal2 v10.4s, v30.8h, v3.8h\n"
+    "ldr x20, [x0, #0x110]\n"
+    "smlal2 v8.4s, v27.8h, v3.8h\n"
+    "smlal v11.4s, v30.4h, v4.4h\n"
+    "smlal2 v21.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x14, x1]\n"
+    "smlal2 v6.4s, v23.8h, v3.8h\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v14.4s, v26.4h, v4.4h\n"
+    "smlal v9.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x8, #0x68]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v28.4h, v4.4h\n"
+    "smlal2 v10.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x19, x1]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal2 v8.4s, v23.8h, v4.8h\n"
+    "smlal v11.4s, v22.4h, v0.4h\n"
+    "ldr x19, [x0, #0x118]\n"
+    "smlal2 v21.4s, v22.8h, v0.8h\n"
+    "smlal2 v6.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x8, #0x70]\n"
+    "ldr d22, [x11, x1]\n"
+    "smlal v14.4s, v25.4h, v0.4h\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v7.4s, v30.4h, v0.4h\n"
+    "smlal2 v10.4s, v25.8h, v0.8h\n"
+    "ushll v22.8h, v22.8b, #0x0\n"
+    "smlal2 v8.4s, v31.8h, v0.8h\n"
+    "smlal v11.4s, v25.4h, v1.4h\n"
+    "smlal2 v21.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x13, x1]\n"
+    "smlal2 v6.4s, v30.8h, v0.8h\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v14.4s, v24.4h, v1.4h\n"
+    "smlal v9.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x8, #0x78]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "smlal v7.4s, v26.4h, v1.4h\n"
+    "smlal2 v10.4s, v24.8h, v1.8h\n"
+    "smlal2 v8.4s, v30.8h, v1.8h\n"
+    "smlal v11.4s, v24.4h, v2.4h\n"
+    "smlal2 v21.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x12, x1]\n"
+    "smlal2 v6.4s, v26.8h, v1.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v14.4s, v27.4h, v2.4h\n"
+    "smlal v9.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x8, #0x80]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v25.4h, v2.4h\n"
+    "smlal2 v10.4s, v27.8h, v2.8h\n"
+    "smlal2 v8.4s, v26.8h, v2.8h\n"
+    "smlal v11.4s, v27.4h, v3.4h\n"
+    "smlal2 v21.4s, v27.8h, v3.8h\n"
+    "smlal2 v6.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x10, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v14.4s, v23.4h, v3.4h\n"
+    "smlal v9.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x8, #0x88]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v24.4h, v3.4h\n"
+    "smlal2 v10.4s, v23.8h, v3.8h\n"
+    "smlal2 v8.4s, v25.8h, v3.8h\n"
+    "smlal v11.4s, v23.4h, v4.4h\n"
+    "smlal2 v21.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x9, x1]\n"
+    "smlal2 v6.4s, v24.8h, v3.8h\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "smlal v14.4s, v28.4h, v4.4h\n"
+    "smlal v9.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x8, #0x90]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v22.4h, v4.4h\n"
+    "smlal2 v10.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x26, x1]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "smlal2 v8.4s, v24.8h, v4.8h\n"
+    "smlal v11.4s, v31.4h, v0.4h\n"
+    "smlal2 v21.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x28, x1]\n"
+    "smlal2 v6.4s, v22.8h, v4.8h\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "smlal v14.4s, v30.4h, v0.4h\n"
+    "smlal v9.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x8, #0x98]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v7.4s, v23.4h, v0.4h\n"
+    "smlal2 v10.4s, v30.8h, v0.8h\n"
+    "smlal2 v8.4s, v27.8h, v0.8h\n"
+    "smlal v11.4s, v30.4h, v1.4h\n"
+    "smlal2 v21.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x1]\n"
+    "smlal2 v6.4s, v23.8h, v0.8h\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v14.4s, v26.4h, v1.4h\n"
+    "smlal v9.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x8, #0xa0]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "smlal v7.4s, v31.4h, v1.4h\n"
+    "smlal2 v10.4s, v26.8h, v1.8h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "smlal v11.4s, v26.4h, v2.4h\n"
+    "smlal2 v21.4s, v26.8h, v2.8h\n"
+    "smlal2 v6.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x25, x1]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v14.4s, v25.4h, v2.4h\n"
+    "smlal v9.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x8, #0xa8]\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "smlal v7.4s, v30.4h, v2.4h\n"
+    "smlal2 v10.4s, v25.8h, v2.8h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "smlal v11.4s, v25.4h, v3.4h\n"
+    "smlal2 v21.4s, v25.8h, v3.8h\n"
+    "smlal2 v6.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x24, x1]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal v14.4s, v24.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x8, #0xb0]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "smlal v7.4s, v28.4h, v3.4h\n"
+    "smlal2 v10.4s, v24.8h, v3.8h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "smlal v11.4s, v24.4h, v4.4h\n"
+    "smlal2 v21.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x23, x1]\n"
+    "smlal2 v6.4s, v28.8h, v3.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal v14.4s, v22.4h, v4.4h\n"
+    "smlal v9.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x8, #0xb8]\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "smlal v7.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "smlal v11.4s, v27.4h, v0.4h\n"
+    "smlal2 v21.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x22, x1]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal2 v10.4s, v22.8h, v4.8h\n"
+    "smlal2 v6.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x8, #0xc0]\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "smlal v14.4s, v23.4h, v0.4h\n"
+    "smlal v9.4s, v25.4h, v0.4h\n"
+    "smlal v7.4s, v24.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x7, x1]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal2 v10.4s, v23.8h, v0.8h\n"
+    "smlal2 v6.4s, v24.8h, v0.8h\n"
+    "smlal v11.4s, v23.4h, v1.4h\n"
+    "smlal v14.4s, v31.4h, v1.4h\n"
+    "smlal v9.4s, v24.4h, v1.4h\n"
+    "smlal v7.4s, v27.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x20, x1]\n"
+    "smlal2 v21.4s, v23.8h, v1.8h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal2 v10.4s, v31.8h, v1.8h\n"
+    "smlal2 v6.4s, v27.8h, v1.8h\n"
+    "smlal v11.4s, v31.4h, v2.4h\n"
+    "smlal v14.4s, v30.4h, v2.4h\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal v7.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x19, x1]\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal2 v10.4s, v30.8h, v2.8h\n"
+    "smlal2 v6.4s, v25.8h, v2.8h\n"
+    "add x1, x1, #0x8\n"
+    "smlal v11.4s, v30.4h, v3.4h\n"
+    "smlal v14.4s, v28.4h, v3.4h\n"
+    "smlal v9.4s, v25.4h, v3.4h\n"
+    "smlal v7.4s, v24.4h, v3.4h\n"
+    "smlal2 v21.4s, v30.8h, v3.8h\n"
+    "smlal2 v10.4s, v28.8h, v3.8h\n"
+    "smlal2 v8.4s, v25.8h, v3.8h\n"
+    "smlal2 v6.4s, v24.8h, v3.8h\n"
+    "smlal v11.4s, v28.4h, v4.4h\n"
+    "smlal v14.4s, v26.4h, v4.4h\n"
+    "sqdmulh v11.4s, v11.4s, v17.4s\n"
+    "smlal v9.4s, v24.4h, v4.4h\n"
+    "smlal v7.4s, v27.4h, v4.4h\n"
+    "sqdmulh v14.4s, v14.4s, v17.4s\n"
+    "smlal2 v21.4s, v28.8h, v4.8h\n"
+    "smlal2 v10.4s, v26.8h, v4.8h\n"
+    "sqdmulh v9.4s, v9.4s, v17.4s\n"
+    "smlal2 v8.4s, v24.8h, v4.8h\n"
+    "smlal2 v6.4s, v27.8h, v4.8h\n"
+    "sqdmulh v7.4s, v7.4s, v17.4s\n"
+    "and v23.16b, v11.16b, v5.16b\n"
+    "sqdmulh v21.4s, v21.4s, v18.4s\n"
+    "and v22.16b, v14.16b, v5.16b\n"
+    "sqdmulh v10.4s, v10.4s, v18.4s\n"
+    "and v17.16b, v9.16b, v5.16b\n"
+    "sqdmulh v8.4s, v8.4s, v18.4s\n"
+    "and v20.16b, v7.16b, v5.16b\n"
+    "sqdmulh v6.4s, v6.4s, v18.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v19.16b, v21.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v18.16b, v10.16b, v29.16b\n"
+    "sshr v17.4s, v17.4s, #0x1f\n"
+    "and v26.16b, v8.16b, v29.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v4.16b, v6.16b, v29.16b\n"
+    "sqadd v11.4s, v11.4s, v23.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v14.4s, v14.4s, v22.4s\n"
+    "sshr v18.4s, v18.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v17.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "sqadd v7.4s, v7.4s, v20.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "srshl v11.4s, v11.4s, v5.4s\n"
+    "sqadd v21.4s, v21.4s, v19.4s\n"
+    "srshl v14.4s, v14.4s, v5.4s\n"
+    "sqadd v10.4s, v10.4s, v18.4s\n"
+    "srshl v9.4s, v9.4s, v5.4s\n"
+    "sqadd v8.4s, v8.4s, v26.4s\n"
+    "srshl v7.4s, v7.4s, v5.4s\n"
+    "sqadd v6.4s, v6.4s, v4.4s\n"
+    "srshl v21.4s, v21.4s, v29.4s\n"
+    "sqxtn v11.4h, v11.4s\n"
+    "srshl v10.4s, v10.4s, v29.4s\n"
+    "sqxtn v14.4h, v14.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v6.4s, v6.4s, v29.4s\n"
+    "sqxtn v7.4h, v7.4s\n"
+    "sqxtn2 v11.8h, v21.4s\n"
+    "sqxtn2 v14.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v8.4s\n"
+    "sqxtn2 v7.8h, v6.4s\n"
+    "sqadd v11.8h, v11.8h, v16.8h\n"
+    "sqadd v14.8h, v14.8h, v16.8h\n"
+    "sqadd v9.8h, v9.8h, v16.8h\n"
+    "sqadd v7.8h, v7.8h, v16.8h\n"
+    "smax v11.8h, v11.8h, v12.8h\n"
+    "smax v14.8h, v14.8h, v12.8h\n"
+    "smax v9.8h, v9.8h, v12.8h\n"
+    "smax v7.8h, v7.8h, v12.8h\n"
+    "smin v11.8h, v11.8h, v13.8h\n"
+    "smin v14.8h, v14.8h, v13.8h\n"
+    "smin v9.8h, v9.8h, v13.8h\n"
+    "smin v7.8h, v7.8h, v13.8h\n"
+    "uzp1 v11.16b, v11.16b, v11.16b\n"
+    "uzp1 v14.16b, v14.16b, v14.16b\n"
+    "str d11, [x21, x2]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v7.16b, v7.16b, v7.16b\n"
+    "str d14, [x15, x2]\n"
+    "str d9, [x17, x2]\n"
+    "str d7, [x16, x2]\n"
+    "add x2, x2, #0x8\n"
+    "beq 124f\n"
+    "add x8, x8, #0xc8\n"
+    "3:"  // Oddments
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "tbz x4, #2, 5f\n"
+    "ld1 { v11.4s }, [x19], #0x10\n"
+    "tbz x4, #1, 4f\n"
+    "ld1 { v21.d }[0], [x19], #0x8\n"
+    "tbz x4, #0, 7f\n"
+    "ld1 { v21.s }[2], [x19]\n"
+    "b 7f\n"
+    "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
+    "tbz x4, #0, 7f\n"
+    "ld1 { v21.s }[0], [x19]\n"
+    "b 7f\n"
+    "5:"  // Oddments: Load bias: Bit 2: Unset
+    "tbz x4, #1, 6f\n"
+    "ld1 { v11.d }[0], [x19], #0x8\n"
+    "tbz x4, #0, 7f\n"
+    "ld1 { v11.s }[2], [x19]\n"
+    "b 7f\n"
+    "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 7f\n"
+    "ld1 { v11.s }[0], [x19]\n"
+    "7:"  // Oddments: Load bias: Bit 2: End
+    "ldr d0, [x8, #0x0]\n"
+    "ldr d1, [x8, #0x8]\n"
+    "mov v14.16b, v11.16b\n"
+    "mov v10.16b, v21.16b\n"
+    "ldr d2, [x8, #0x10]\n"
+    "ldr d3, [x8, #0x18]\n"
+    "mov v9.16b, v11.16b\n"
+    "mov v8.16b, v21.16b\n"
+    "ldr d4, [x8, #0x20]\n"
+    "ldp x28, x27, [x0, #0x0]\n"
+    "mov v7.16b, v11.16b\n"
+    "mov v6.16b, v21.16b\n"
+    "ldp x10, x26, [x0, #0x10]\n"
+    "ldp x24, x23, [x0, #0x20]\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldp x22, x25, [x0, #0x30]\n"
+    "ldp x20, x19, [x0, #0x40]\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "add x28, x28, x1\n"
+    "add x27, x27, x1\n"
+    "add x10, x10, x1\n"
+    "add x26, x26, x1\n"
+    "add x24, x24, x1\n"
+    "add x23, x23, x1\n"
+    "add x22, x22, x1\n"
+    "add x25, x25, x1\n"
+    "add x20, x20, x1\n"
+    "add x19, x19, x1\n"
+    "tbz x4, #2, 9f\n"
+    "ld1 { v31.s }[0], [x28], #0x4\n"
+    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "ld1 { v29.s }[0], [x10], #0x4\n"
+    "ld1 { v28.s }[0], [x26], #0x4\n"
+    "ld1 { v27.s }[0], [x24], #0x4\n"
+    "ld1 { v23.s }[0], [x23], #0x4\n"
+    "ld1 { v25.s }[0], [x22], #0x4\n"
+    "ld1 { v24.s }[0], [x25], #0x4\n"
+    "ld1 { v26.s }[0], [x20], #0x4\n"
+    "ld1 { v22.s }[0], [x19], #0x4\n"
+    "tbz x4, #1, 8f\n"
+    "ld1 { v31.h }[2], [x28], #0x2\n"
+    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "ld1 { v29.h }[2], [x10], #0x2\n"
+    "ld1 { v28.h }[2], [x26], #0x2\n"
+    "ld1 { v27.h }[2], [x24], #0x2\n"
+    "ld1 { v23.h }[2], [x23], #0x2\n"
+    "ld1 { v25.h }[2], [x22], #0x2\n"
+    "ld1 { v24.h }[2], [x25], #0x2\n"
+    "ld1 { v26.h }[2], [x20], #0x2\n"
+    "ld1 { v22.h }[2], [x19], #0x2\n"
+    "tbz x4, #0, 11f\n"
+    "ld1 { v31.b }[6], [x28]\n"
+    "ld1 { v30.b }[6], [x27]\n"
+    "ld1 { v29.b }[6], [x10]\n"
+    "ld1 { v28.b }[6], [x26]\n"
+    "ld1 { v27.b }[6], [x24]\n"
+    "ld1 { v23.b }[6], [x23]\n"
+    "ld1 { v25.b }[6], [x22]\n"
+    "ld1 { v24.b }[6], [x25]\n"
+    "ld1 { v26.b }[6], [x20]\n"
+    "ld1 { v22.b }[6], [x19]\n"
+    "b 11f\n"
+    "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
+    "tbz x4, #0, 11f\n"
+    "ld1 { v31.b }[4], [x28]\n"
+    "ld1 { v30.b }[4], [x27]\n"
+    "ld1 { v29.b }[4], [x10]\n"
+    "ld1 { v28.b }[4], [x26]\n"
+    "ld1 { v27.b }[4], [x24]\n"
+    "ld1 { v23.b }[4], [x23]\n"
+    "ld1 { v25.b }[4], [x22]\n"
+    "ld1 { v24.b }[4], [x25]\n"
+    "ld1 { v26.b }[4], [x20]\n"
+    "ld1 { v22.b }[4], [x19]\n"
+    "b 11f\n"
+    "9:"  // Oddments: Initial loads: Bit 2: Unset
+    "tbz x4, #1, 10f\n"
+    "ld1 { v31.h }[0], [x28], #0x2\n"
+    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "ld1 { v29.h }[0], [x10], #0x2\n"
+    "ld1 { v28.h }[0], [x26], #0x2\n"
+    "ld1 { v27.h }[0], [x24], #0x2\n"
+    "ld1 { v23.h }[0], [x23], #0x2\n"
+    "ld1 { v25.h }[0], [x22], #0x2\n"
+    "ld1 { v24.h }[0], [x25], #0x2\n"
+    "ld1 { v26.h }[0], [x20], #0x2\n"
+    "ld1 { v22.h }[0], [x19], #0x2\n"
+    "tbz x4, #0, 11f\n"
+    "ld1 { v31.b }[2], [x28]\n"
+    "ld1 { v30.b }[2], [x27]\n"
+    "ld1 { v29.b }[2], [x10]\n"
+    "ld1 { v28.b }[2], [x26]\n"
+    "ld1 { v27.b }[2], [x24]\n"
+    "ld1 { v23.b }[2], [x23]\n"
+    "ld1 { v25.b }[2], [x22]\n"
+    "ld1 { v24.b }[2], [x25]\n"
+    "ld1 { v26.b }[2], [x20]\n"
+    "ld1 { v22.b }[2], [x19]\n"
+    "b 11f\n"
+    "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 11f\n"
+    "ld1 { v31.b }[0], [x28]\n"
+    "ld1 { v30.b }[0], [x27]\n"
+    "ld1 { v29.b }[0], [x10]\n"
+    "ld1 { v28.b }[0], [x26]\n"
+    "ld1 { v27.b }[0], [x24]\n"
+    "ld1 { v23.b }[0], [x23]\n"
+    "ld1 { v25.b }[0], [x22]\n"
+    "ld1 { v24.b }[0], [x25]\n"
+    "ld1 { v26.b }[0], [x20]\n"
+    "ld1 { v22.b }[0], [x19]\n"
+    "11:"  // Oddments: Initial loads: Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "smlal v11.4s, v31.4h, v0.4h\n"
+    "ldr x19, [x0, #0x50]\n"
+    "ushll v29.8h, v29.8b, #0x0\n"
+    "smlal2 v21.4s, v31.8h, v0.8h\n"
+    "smlal v14.4s, v30.4h, v0.4h\n"
+    "smlal2 v10.4s, v30.8h, v0.8h\n"
+    "smlal v9.4s, v29.4h, v0.4h\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "add x19, x19, x1\n"
+    "smlal2 v8.4s, v29.8h, v0.8h\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v7.4s, v28.4h, v0.4h\n"
+    "smlal2 v6.4s, v28.8h, v0.8h\n"
+    "smlal v11.4s, v30.4h, v1.4h\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "smlal2 v21.4s, v30.8h, v1.8h\n"
+    "smlal v14.4s, v27.4h, v1.4h\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "smlal2 v10.4s, v27.8h, v1.8h\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "smlal2 v8.4s, v28.8h, v1.8h\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "smlal v7.4s, v23.4h, v1.4h\n"
+    "ushll v22.8h, v22.8b, #0x0\n"
+    "smlal2 v6.4s, v23.8h, v1.8h\n"
+    "smlal v11.4s, v27.4h, v2.4h\n"
+    "smlal2 v21.4s, v27.8h, v2.8h\n"
+    "smlal v14.4s, v25.4h, v2.4h\n"
+    "smlal2 v10.4s, v25.8h, v2.8h\n"
+    "smlal v9.4s, v23.4h, v2.4h\n"
+    "smlal2 v8.4s, v23.8h, v2.8h\n"
+    "tbz x4, #2, 13f\n"
+    "ld1 { v31.s }[0], [x19], #0x4\n"
+    "tbz x4, #1, 12f\n"
+    "ld1 { v31.h }[2], [x19], #0x2\n"
+    "tbz x4, #0, 15f\n"
+    "ld1 { v31.b }[6], [x19]\n"
+    "b 15f\n"
+    "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 15f\n"
+    "ld1 { v31.b }[4], [x19]\n"
+    "b 15f\n"
+    "13:"  // Oddments: Load (1, 3): Bit 2: Unset
+    "tbz x4, #1, 14f\n"
+    "ld1 { v31.h }[0], [x19], #0x2\n"
+    "tbz x4, #0, 15f\n"
+    "ld1 { v31.b }[2], [x19]\n"
+    "b 15f\n"
+    "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 15f\n"
+    "ld1 { v31.b }[0], [x19]\n"
+    "15:"  // Oddments: Load (1, 3): Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ldr x20, [x0, #0x58]\n"
+    "smlal v7.4s, v31.4h, v2.4h\n"
+    "smlal2 v6.4s, v31.8h, v2.8h\n"
+    "smlal v11.4s, v25.4h, v3.4h\n"
+    "smlal2 v21.4s, v25.8h, v3.8h\n"
+    "add x20, x20, x1\n"
+    "smlal v14.4s, v24.4h, v3.4h\n"
+    "smlal2 v10.4s, v24.8h, v3.8h\n"
+    "smlal v9.4s, v31.4h, v3.4h\n"
+    "smlal2 v8.4s, v31.8h, v3.8h\n"
+    "tbz x4, #2, 17f\n"
+    "ld1 { v30.s }[0], [x20], #0x4\n"
+    "tbz x4, #1, 16f\n"
+    "ld1 { v30.h }[2], [x20], #0x2\n"
+    "tbz x4, #0, 19f\n"
+    "ld1 { v30.b }[6], [x20]\n"
+    "b 19f\n"
+    "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 19f\n"
+    "ld1 { v30.b }[4], [x20]\n"
+    "b 19f\n"
+    "17:"  // Oddments: Load (1, 4): Bit 2: Unset
+    "tbz x4, #1, 18f\n"
+    "ld1 { v30.h }[0], [x20], #0x2\n"
+    "tbz x4, #0, 19f\n"
+    "ld1 { v30.b }[2], [x20]\n"
+    "b 19f\n"
+    "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 19f\n"
+    "ld1 { v30.b }[0], [x20]\n"
+    "19:"  // Oddments: Load (1, 4): Bit 2: End
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "ldr x19, [x0, #0x60]\n"
+    "smlal v7.4s, v30.4h, v3.4h\n"
+    "smlal2 v6.4s, v30.8h, v3.8h\n"
+    "smlal v11.4s, v24.4h, v4.4h\n"
+    "smlal2 v21.4s, v24.8h, v4.8h\n"
+    "add x19, x19, x1\n"
+    "tbz x4, #2, 21f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x4, #1, 20f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x4, #0, 23f\n"
+    "ld1 { v27.b }[6], [x19]\n"
+    "b 23f\n"
+    "20:"  // Oddments: Load (0, 5): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 23f\n"
+    "ld1 { v27.b }[4], [x19]\n"
+    "b 23f\n"
+    "21:"  // Oddments: Load (0, 5): Bit 2: Unset
+    "tbz x4, #1, 22f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x4, #0, 23f\n"
+    "ld1 { v27.b }[2], [x19]\n"
+    "b 23f\n"
+    "22:"  // Oddments: Load (0, 5): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 23f\n"
+    "ld1 { v27.b }[0], [x19]\n"
+    "23:"  // Oddments: Load (0, 5): Bit 2: End
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ldr d0, [x8, #0x28]\n"
+    "smlal v14.4s, v27.4h, v4.4h\n"
+    "smlal2 v10.4s, v27.8h, v4.8h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "smlal2 v8.4s, v30.8h, v4.8h\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "ldr x24, [x0, #0x68]\n"
+    "smlal v7.4s, v26.4h, v4.4h\n"
+    "smlal2 v6.4s, v26.8h, v4.8h\n"
+    "add x24, x24, x1\n"
+    "smlal v11.4s, v29.4h, v0.4h\n"
+    "smlal2 v21.4s, v29.8h, v0.8h\n"
+    "smlal v14.4s, v28.4h, v0.4h\n"
+    "smlal2 v10.4s, v28.8h, v0.8h\n"
+    "smlal v9.4s, v22.4h, v0.4h\n"
+    "smlal2 v8.4s, v22.8h, v0.8h\n"
+    "tbz x4, #2, 25f\n"
+    "ld1 { v25.s }[0], [x24], #0x4\n"
+    "tbz x4, #1, 24f\n"
+    "ld1 { v25.h }[2], [x24], #0x2\n"
+    "tbz x4, #0, 27f\n"
+    "ld1 { v25.b }[6], [x24]\n"
+    "b 27f\n"
+    "24:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 27f\n"
+    "ld1 { v25.b }[4], [x24]\n"
+    "b 27f\n"
+    "25:"  // Oddments: Load (2, 1): Bit 2: Unset
+    "tbz x4, #1, 26f\n"
+    "ld1 { v25.h }[0], [x24], #0x2\n"
+    "tbz x4, #0, 27f\n"
+    "ld1 { v25.b }[2], [x24]\n"
+    "b 27f\n"
+    "26:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 27f\n"
+    "ld1 { v25.b }[0], [x24]\n"
+    "27:"  // Oddments: Load (2, 1): Bit 2: End
+    "ldr d1, [x8, #0x30]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldr x23, [x0, #0x70]\n"
+    "smlal v7.4s, v25.4h, v0.4h\n"
+    "smlal2 v6.4s, v25.8h, v0.8h\n"
+    "add x23, x23, x1\n"
+    "smlal v11.4s, v28.4h, v1.4h\n"
+    "smlal2 v21.4s, v28.8h, v1.8h\n"
+    "smlal v14.4s, v23.4h, v1.4h\n"
+    "smlal2 v10.4s, v23.8h, v1.8h\n"
+    "smlal v9.4s, v25.4h, v1.4h\n"
+    "smlal2 v8.4s, v25.8h, v1.8h\n"
+    "tbz x4, #2, 29f\n"
+    "ld1 { v24.s }[0], [x23], #0x4\n"
+    "tbz x4, #1, 28f\n"
+    "ld1 { v24.h }[2], [x23], #0x2\n"
+    "tbz x4, #0, 31f\n"
+    "ld1 { v24.b }[6], [x23]\n"
+    "b 31f\n"
+    "28:"  // Oddments: Load (2, 2): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 31f\n"
+    "ld1 { v24.b }[4], [x23]\n"
+    "b 31f\n"
+    "29:"  // Oddments: Load (2, 2): Bit 2: Unset
+    "tbz x4, #1, 30f\n"
+    "ld1 { v24.h }[0], [x23], #0x2\n"
+    "tbz x4, #0, 31f\n"
+    "ld1 { v24.b }[2], [x23]\n"
+    "b 31f\n"
+    "30:"  // Oddments: Load (2, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 31f\n"
+    "ld1 { v24.b }[0], [x23]\n"
+    "31:"  // Oddments: Load (2, 2): Bit 2: End
+    "ldr d2, [x8, #0x38]\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "ldr x26, [x0, #0x78]\n"
+    "smlal v7.4s, v24.4h, v1.4h\n"
+    "smlal2 v6.4s, v24.8h, v1.8h\n"
+    "add x26, x26, x1\n"
+    "smlal v11.4s, v23.4h, v2.4h\n"
+    "smlal2 v21.4s, v23.8h, v2.8h\n"
+    "smlal v14.4s, v31.4h, v2.4h\n"
+    "smlal2 v10.4s, v31.8h, v2.8h\n"
+    "smlal v9.4s, v24.4h, v2.4h\n"
+    "smlal2 v8.4s, v24.8h, v2.8h\n"
+    "tbz x4, #2, 33f\n"
+    "ld1 { v27.s }[0], [x26], #0x4\n"
+    "tbz x4, #1, 32f\n"
+    "ld1 { v27.h }[2], [x26], #0x2\n"
+    "tbz x4, #0, 35f\n"
+    "ld1 { v27.b }[6], [x26]\n"
+    "b 35f\n"
+    "32:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 35f\n"
+    "ld1 { v27.b }[4], [x26]\n"
+    "b 35f\n"
+    "33:"  // Oddments: Load (2, 3): Bit 2: Unset
+    "tbz x4, #1, 34f\n"
+    "ld1 { v27.h }[0], [x26], #0x2\n"
+    "tbz x4, #0, 35f\n"
+    "ld1 { v27.b }[2], [x26]\n"
+    "b 35f\n"
+    "34:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 35f\n"
+    "ld1 { v27.b }[0], [x26]\n"
+    "35:"  // Oddments: Load (2, 3): Bit 2: End
+    "ldr d3, [x8, #0x40]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "ldr x7, [x0, #0x80]\n"
+    "smlal v7.4s, v27.4h, v2.4h\n"
+    "smlal2 v6.4s, v27.8h, v2.8h\n"
+    "add x7, x7, x1\n"
+    "smlal v11.4s, v31.4h, v3.4h\n"
+    "smlal2 v21.4s, v31.8h, v3.8h\n"
+    "smlal v14.4s, v30.4h, v3.4h\n"
+    "smlal2 v10.4s, v30.8h, v3.8h\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v8.4s, v27.8h, v3.8h\n"
+    "tbz x4, #2, 37f\n"
+    "ld1 { v23.s }[0], [x7], #0x4\n"
+    "tbz x4, #1, 36f\n"
+    "ld1 { v23.h }[2], [x7], #0x2\n"
+    "tbz x4, #0, 39f\n"
+    "ld1 { v23.b }[6], [x7]\n"
+    "b 39f\n"
+    "36:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 39f\n"
+    "ld1 { v23.b }[4], [x7]\n"
+    "b 39f\n"
+    "37:"  // Oddments: Load (2, 4): Bit 2: Unset
+    "tbz x4, #1, 38f\n"
+    "ld1 { v23.h }[0], [x7], #0x2\n"
+    "tbz x4, #0, 39f\n"
+    "ld1 { v23.b }[2], [x7]\n"
+    "b 39f\n"
+    "38:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 39f\n"
+    "ld1 { v23.b }[0], [x7]\n"
+    "39:"  // Oddments: Load (2, 4): Bit 2: End
+    "ldr d4, [x8, #0x48]\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "ldr x22, [x0, #0x88]\n"
+    "smlal v7.4s, v23.4h, v3.4h\n"
+    "smlal2 v6.4s, v23.8h, v3.8h\n"
+    "add x22, x22, x1\n"
+    "smlal v11.4s, v30.4h, v4.4h\n"
+    "smlal2 v21.4s, v30.8h, v4.8h\n"
+    "smlal v14.4s, v26.4h, v4.4h\n"
+    "smlal2 v10.4s, v26.8h, v4.8h\n"
+    "smlal v9.4s, v23.4h, v4.4h\n"
+    "smlal2 v8.4s, v23.8h, v4.8h\n"
+    "tbz x4, #2, 41f\n"
+    "ld1 { v28.s }[0], [x22], #0x4\n"
+    "tbz x4, #1, 40f\n"
+    "ld1 { v28.h }[2], [x22], #0x2\n"
+    "tbz x4, #0, 43f\n"
+    "ld1 { v28.b }[6], [x22]\n"
+    "b 43f\n"
+    "40:"  // Oddments: Load (2, 5): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 43f\n"
+    "ld1 { v28.b }[4], [x22]\n"
+    "b 43f\n"
+    "41:"  // Oddments: Load (2, 5): Bit 2: Unset
+    "tbz x4, #1, 42f\n"
+    "ld1 { v28.h }[0], [x22], #0x2\n"
+    "tbz x4, #0, 43f\n"
+    "ld1 { v28.b }[2], [x22]\n"
+    "b 43f\n"
+    "42:"  // Oddments: Load (2, 5): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 43f\n"
+    "ld1 { v28.b }[0], [x22]\n"
+    "43:"  // Oddments: Load (2, 5): Bit 2: End
+    "ldr d0, [x8, #0x50]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "ldr x20, [x0, #0x90]\n"
+    "smlal v7.4s, v28.4h, v4.4h\n"
+    "smlal2 v6.4s, v28.8h, v4.8h\n"
+    "add x20, x20, x1\n"
+    "smlal v11.4s, v22.4h, v0.4h\n"
+    "smlal2 v21.4s, v22.8h, v0.8h\n"
+    "smlal v14.4s, v25.4h, v0.4h\n"
+    "smlal2 v10.4s, v25.8h, v0.8h\n"
+    "tbz x4, #2, 45f\n"
+    "ld1 { v31.s }[0], [x20], #0x4\n"
+    "tbz x4, #1, 44f\n"
+    "ld1 { v31.h }[2], [x20], #0x2\n"
+    "tbz x4, #0, 47f\n"
+    "ld1 { v31.b }[6], [x20]\n"
+    "b 47f\n"
+    "44:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 47f\n"
+    "ld1 { v31.b }[4], [x20]\n"
+    "b 47f\n"
+    "45:"  // Oddments: Load (3, 0): Bit 2: Unset
+    "tbz x4, #1, 46f\n"
+    "ld1 { v31.h }[0], [x20], #0x2\n"
+    "tbz x4, #0, 47f\n"
+    "ld1 { v31.b }[2], [x20]\n"
+    "b 47f\n"
+    "46:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 47f\n"
+    "ld1 { v31.b }[0], [x20]\n"
+    "47:"  // Oddments: Load (3, 0): Bit 2: End
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "ldr x14, [x0, #0x98]\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v8.4s, v31.8h, v0.8h\n"
+    "add x14, x14, x1\n"
+    "tbz x4, #2, 49f\n"
+    "ld1 { v30.s }[0], [x14], #0x4\n"
+    "tbz x4, #1, 48f\n"
+    "ld1 { v30.h }[2], [x14], #0x2\n"
+    "tbz x4, #0, 51f\n"
+    "ld1 { v30.b }[6], [x14]\n"
+    "b 51f\n"
+    "48:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 51f\n"
+    "ld1 { v30.b }[4], [x14]\n"
+    "b 51f\n"
+    "49:"  // Oddments: Load (3, 1): Bit 2: Unset
+    "tbz x4, #1, 50f\n"
+    "ld1 { v30.h }[0], [x14], #0x2\n"
+    "tbz x4, #0, 51f\n"
+    "ld1 { v30.b }[2], [x14]\n"
+    "b 51f\n"
+    "50:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 51f\n"
+    "ld1 { v30.b }[0], [x14]\n"
+    "51:"  // Oddments: Load (3, 1): Bit 2: End
+    "ldr d1, [x8, #0x58]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldr x19, [x0, #0xa0]\n"
+    "smlal v7.4s, v30.4h, v0.4h\n"
+    "smlal2 v6.4s, v30.8h, v0.8h\n"
+    "add x19, x19, x1\n"
+    "smlal v11.4s, v25.4h, v1.4h\n"
+    "smlal2 v21.4s, v25.8h, v1.8h\n"
+    "smlal v14.4s, v24.4h, v1.4h\n"
+    "smlal2 v10.4s, v24.8h, v1.8h\n"
+    "smlal v9.4s, v30.4h, v1.4h\n"
+    "smlal2 v8.4s, v30.8h, v1.8h\n"
+    "tbz x4, #2, 53f\n"
+    "ld1 { v26.s }[0], [x19], #0x4\n"
+    "tbz x4, #1, 52f\n"
+    "ld1 { v26.h }[2], [x19], #0x2\n"
+    "tbz x4, #0, 55f\n"
+    "ld1 { v26.b }[6], [x19]\n"
+    "b 55f\n"
+    "52:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 55f\n"
+    "ld1 { v26.b }[4], [x19]\n"
+    "b 55f\n"
+    "53:"  // Oddments: Load (3, 2): Bit 2: Unset
+    "tbz x4, #1, 54f\n"
+    "ld1 { v26.h }[0], [x19], #0x2\n"
+    "tbz x4, #0, 55f\n"
+    "ld1 { v26.b }[2], [x19]\n"
+    "b 55f\n"
+    "54:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 55f\n"
+    "ld1 { v26.b }[0], [x19]\n"
+    "55:"  // Oddments: Load (3, 2): Bit 2: End
+    "ldr d2, [x8, #0x60]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "ldr x13, [x0, #0xa8]\n"
+    "smlal v7.4s, v26.4h, v1.4h\n"
+    "smlal2 v6.4s, v26.8h, v1.8h\n"
+    "add x13, x13, x1\n"
+    "smlal v11.4s, v24.4h, v2.4h\n"
+    "smlal2 v21.4s, v24.8h, v2.8h\n"
+    "smlal v14.4s, v27.4h, v2.4h\n"
+    "smlal2 v10.4s, v27.8h, v2.8h\n"
+    "smlal v9.4s, v26.4h, v2.4h\n"
+    "smlal2 v8.4s, v26.8h, v2.8h\n"
+    "tbz x4, #2, 57f\n"
+    "ld1 { v25.s }[0], [x13], #0x4\n"
+    "tbz x4, #1, 56f\n"
+    "ld1 { v25.h }[2], [x13], #0x2\n"
+    "tbz x4, #0, 59f\n"
+    "ld1 { v25.b }[6], [x13]\n"
+    "b 59f\n"
+    "56:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 59f\n"
+    "ld1 { v25.b }[4], [x13]\n"
+    "b 59f\n"
+    "57:"  // Oddments: Load (3, 3): Bit 2: Unset
+    "tbz x4, #1, 58f\n"
+    "ld1 { v25.h }[0], [x13], #0x2\n"
+    "tbz x4, #0, 59f\n"
+    "ld1 { v25.b }[2], [x13]\n"
+    "b 59f\n"
+    "58:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 59f\n"
+    "ld1 { v25.b }[0], [x13]\n"
+    "59:"  // Oddments: Load (3, 3): Bit 2: End
+    "ldr d3, [x8, #0x68]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "ldr x12, [x0, #0xb0]\n"
+    "smlal v7.4s, v25.4h, v2.4h\n"
+    "smlal2 v6.4s, v25.8h, v2.8h\n"
+    "add x12, x12, x1\n"
+    "smlal v11.4s, v27.4h, v3.4h\n"
+    "smlal2 v21.4s, v27.8h, v3.8h\n"
+    "smlal v14.4s, v23.4h, v3.4h\n"
+    "smlal2 v10.4s, v23.8h, v3.8h\n"
+    "smlal v9.4s, v25.4h, v3.4h\n"
+    "smlal2 v8.4s, v25.8h, v3.8h\n"
+    "tbz x4, #2, 61f\n"
+    "ld1 { v24.s }[0], [x12], #0x4\n"
+    "tbz x4, #1, 60f\n"
+    "ld1 { v24.h }[2], [x12], #0x2\n"
+    "tbz x4, #0, 63f\n"
+    "ld1 { v24.b }[6], [x12]\n"
+    "b 63f\n"
+    "60:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 63f\n"
+    "ld1 { v24.b }[4], [x12]\n"
+    "b 63f\n"
+    "61:"  // Oddments: Load (3, 4): Bit 2: Unset
+    "tbz x4, #1, 62f\n"
+    "ld1 { v24.h }[0], [x12], #0x2\n"
+    "tbz x4, #0, 63f\n"
+    "ld1 { v24.b }[2], [x12]\n"
+    "b 63f\n"
+    "62:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 63f\n"
+    "ld1 { v24.b }[0], [x12]\n"
+    "63:"  // Oddments: Load (3, 4): Bit 2: End
+    "ldr d4, [x8, #0x70]\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "ldr x11, [x0, #0xb8]\n"
+    "smlal v7.4s, v24.4h, v3.4h\n"
+    "smlal2 v6.4s, v24.8h, v3.8h\n"
+    "add x11, x11, x1\n"
+    "smlal v11.4s, v23.4h, v4.4h\n"
+    "smlal2 v21.4s, v23.8h, v4.8h\n"
+    "smlal v14.4s, v28.4h, v4.4h\n"
+    "smlal2 v10.4s, v28.8h, v4.8h\n"
+    "smlal v9.4s, v24.4h, v4.4h\n"
+    "smlal2 v8.4s, v24.8h, v4.8h\n"
+    "tbz x4, #2, 65f\n"
+    "ld1 { v22.s }[0], [x11], #0x4\n"
+    "tbz x4, #1, 64f\n"
+    "ld1 { v22.h }[2], [x11], #0x2\n"
+    "tbz x4, #0, 67f\n"
+    "ld1 { v22.b }[6], [x11]\n"
+    "b 67f\n"
+    "64:"  // Oddments: Load (3, 5): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 67f\n"
+    "ld1 { v22.b }[4], [x11]\n"
+    "b 67f\n"
+    "65:"  // Oddments: Load (3, 5): Bit 2: Unset
+    "tbz x4, #1, 66f\n"
+    "ld1 { v22.h }[0], [x11], #0x2\n"
+    "tbz x4, #0, 67f\n"
+    "ld1 { v22.b }[2], [x11]\n"
+    "b 67f\n"
+    "66:"  // Oddments: Load (3, 5): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 67f\n"
+    "ld1 { v22.b }[0], [x11]\n"
+    "67:"  // Oddments: Load (3, 5): Bit 2: End
+    "ldr d0, [x8, #0x78]\n"
+    "ushll v22.8h, v22.8b, #0x0\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "ldr x10, [x0, #0xc0]\n"
+    "smlal v7.4s, v22.4h, v4.4h\n"
+    "smlal2 v6.4s, v22.8h, v4.8h\n"
+    "add x10, x10, x1\n"
+    "smlal v11.4s, v31.4h, v0.4h\n"
+    "smlal2 v21.4s, v31.8h, v0.8h\n"
+    "smlal v14.4s, v30.4h, v0.4h\n"
+    "smlal2 v10.4s, v30.8h, v0.8h\n"
+    "tbz x4, #2, 69f\n"
+    "ld1 { v27.s }[0], [x10], #0x4\n"
+    "tbz x4, #1, 68f\n"
+    "ld1 { v27.h }[2], [x10], #0x2\n"
+    "tbz x4, #0, 71f\n"
+    "ld1 { v27.b }[6], [x10]\n"
+    "b 71f\n"
+    "68:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 71f\n"
+    "ld1 { v27.b }[4], [x10]\n"
+    "b 71f\n"
+    "69:"  // Oddments: Load (4, 0): Bit 2: Unset
+    "tbz x4, #1, 70f\n"
+    "ld1 { v27.h }[0], [x10], #0x2\n"
+    "tbz x4, #0, 71f\n"
+    "ld1 { v27.b }[2], [x10]\n"
+    "b 71f\n"
+    "70:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 71f\n"
+    "ld1 { v27.b }[0], [x10]\n"
+    "71:"  // Oddments: Load (4, 0): Bit 2: End
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "ldr x9, [x0, #0xc8]\n"
+    "smlal v9.4s, v27.4h, v0.4h\n"
+    "smlal2 v8.4s, v27.8h, v0.8h\n"
+    "add x9, x9, x1\n"
+    "tbz x4, #2, 73f\n"
+    "ld1 { v23.s }[0], [x9], #0x4\n"
+    "tbz x4, #1, 72f\n"
+    "ld1 { v23.h }[2], [x9], #0x2\n"
+    "tbz x4, #0, 75f\n"
+    "ld1 { v23.b }[6], [x9]\n"
+    "b 75f\n"
+    "72:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 75f\n"
+    "ld1 { v23.b }[4], [x9]\n"
+    "b 75f\n"
+    "73:"  // Oddments: Load (4, 1): Bit 2: Unset
+    "tbz x4, #1, 74f\n"
+    "ld1 { v23.h }[0], [x9], #0x2\n"
+    "tbz x4, #0, 75f\n"
+    "ld1 { v23.b }[2], [x9]\n"
+    "b 75f\n"
+    "74:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 75f\n"
+    "ld1 { v23.b }[0], [x9]\n"
+    "75:"  // Oddments: Load (4, 1): Bit 2: End
+    "ldr d1, [x8, #0x80]\n"
+    "ushll v23.8h, v23.8b, #0x0\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldr x28, [x0, #0xd0]\n"
+    "smlal v7.4s, v23.4h, v0.4h\n"
+    "smlal2 v6.4s, v23.8h, v0.8h\n"
+    "add x28, x28, x1\n"
+    "smlal v11.4s, v30.4h, v1.4h\n"
+    "smlal2 v21.4s, v30.8h, v1.8h\n"
+    "smlal v14.4s, v26.4h, v1.4h\n"
+    "smlal2 v10.4s, v26.8h, v1.8h\n"
+    "smlal v9.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "tbz x4, #2, 77f\n"
+    "ld1 { v31.s }[0], [x28], #0x4\n"
+    "tbz x4, #1, 76f\n"
+    "ld1 { v31.h }[2], [x28], #0x2\n"
+    "tbz x4, #0, 79f\n"
+    "ld1 { v31.b }[6], [x28]\n"
+    "b 79f\n"
+    "76:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 79f\n"
+    "ld1 { v31.b }[4], [x28]\n"
+    "b 79f\n"
+    "77:"  // Oddments: Load (4, 2): Bit 2: Unset
+    "tbz x4, #1, 78f\n"
+    "ld1 { v31.h }[0], [x28], #0x2\n"
+    "tbz x4, #0, 79f\n"
+    "ld1 { v31.b }[2], [x28]\n"
+    "b 79f\n"
+    "78:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 79f\n"
+    "ld1 { v31.b }[0], [x28]\n"
+    "79:"  // Oddments: Load (4, 2): Bit 2: End
+    "ldr d2, [x8, #0x88]\n"
+    "ushll v31.8h, v31.8b, #0x0\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "ldr x27, [x0, #0xd8]\n"
+    "smlal v7.4s, v31.4h, v1.4h\n"
+    "smlal2 v6.4s, v31.8h, v1.8h\n"
+    "add x27, x27, x1\n"
+    "smlal v11.4s, v26.4h, v2.4h\n"
+    "smlal2 v21.4s, v26.8h, v2.8h\n"
+    "smlal v14.4s, v25.4h, v2.4h\n"
+    "smlal2 v10.4s, v25.8h, v2.8h\n"
+    "smlal v9.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "tbz x4, #2, 81f\n"
+    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "tbz x4, #1, 80f\n"
+    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "tbz x4, #0, 83f\n"
+    "ld1 { v30.b }[6], [x27]\n"
+    "b 83f\n"
+    "80:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 83f\n"
+    "ld1 { v30.b }[4], [x27]\n"
+    "b 83f\n"
+    "81:"  // Oddments: Load (4, 3): Bit 2: Unset
+    "tbz x4, #1, 82f\n"
+    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "tbz x4, #0, 83f\n"
+    "ld1 { v30.b }[2], [x27]\n"
+    "b 83f\n"
+    "82:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 83f\n"
+    "ld1 { v30.b }[0], [x27]\n"
+    "83:"  // Oddments: Load (4, 3): Bit 2: End
+    "ldr d3, [x8, #0x90]\n"
+    "ushll v30.8h, v30.8b, #0x0\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "ldr x26, [x0, #0xe0]\n"
+    "smlal v7.4s, v30.4h, v2.4h\n"
+    "smlal2 v6.4s, v30.8h, v2.8h\n"
+    "add x26, x26, x1\n"
+    "smlal v11.4s, v25.4h, v3.4h\n"
+    "smlal2 v21.4s, v25.8h, v3.8h\n"
+    "smlal v14.4s, v24.4h, v3.4h\n"
+    "smlal2 v10.4s, v24.8h, v3.8h\n"
+    "smlal v9.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "tbz x4, #2, 85f\n"
+    "ld1 { v28.s }[0], [x26], #0x4\n"
+    "tbz x4, #1, 84f\n"
+    "ld1 { v28.h }[2], [x26], #0x2\n"
+    "tbz x4, #0, 87f\n"
+    "ld1 { v28.b }[6], [x26]\n"
+    "b 87f\n"
+    "84:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 87f\n"
+    "ld1 { v28.b }[4], [x26]\n"
+    "b 87f\n"
+    "85:"  // Oddments: Load (4, 4): Bit 2: Unset
+    "tbz x4, #1, 86f\n"
+    "ld1 { v28.h }[0], [x26], #0x2\n"
+    "tbz x4, #0, 87f\n"
+    "ld1 { v28.b }[2], [x26]\n"
+    "b 87f\n"
+    "86:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 87f\n"
+    "ld1 { v28.b }[0], [x26]\n"
+    "87:"  // Oddments: Load (4, 4): Bit 2: End
+    "ldr d4, [x8, #0x98]\n"
+    "ushll v28.8h, v28.8b, #0x0\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "ldr x25, [x0, #0xe8]\n"
+    "smlal v7.4s, v28.4h, v3.4h\n"
+    "smlal2 v6.4s, v28.8h, v3.8h\n"
+    "add x25, x25, x1\n"
+    "smlal v11.4s, v24.4h, v4.4h\n"
+    "smlal2 v21.4s, v24.8h, v4.8h\n"
+    "smlal v14.4s, v22.4h, v4.4h\n"
+    "smlal2 v10.4s, v22.8h, v4.8h\n"
+    "smlal v9.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "tbz x4, #2, 89f\n"
+    "ld1 { v26.s }[0], [x25], #0x4\n"
+    "tbz x4, #1, 88f\n"
+    "ld1 { v26.h }[2], [x25], #0x2\n"
+    "tbz x4, #0, 91f\n"
+    "ld1 { v26.b }[6], [x25]\n"
+    "b 91f\n"
+    "88:"  // Oddments: Load (4, 5): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 91f\n"
+    "ld1 { v26.b }[4], [x25]\n"
+    "b 91f\n"
+    "89:"  // Oddments: Load (4, 5): Bit 2: Unset
+    "tbz x4, #1, 90f\n"
+    "ld1 { v26.h }[0], [x25], #0x2\n"
+    "tbz x4, #0, 91f\n"
+    "ld1 { v26.b }[2], [x25]\n"
+    "b 91f\n"
+    "90:"  // Oddments: Load (4, 5): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 91f\n"
+    "ld1 { v26.b }[0], [x25]\n"
+    "91:"  // Oddments: Load (4, 5): Bit 2: End
+    "ldr d0, [x8, #0xa0]\n"
+    "ushll v26.8h, v26.8b, #0x0\n"
+    "usubl v0.8h, v0.8b, v15.8b\n"
+    "ldr x24, [x0, #0xf0]\n"
+    "smlal v7.4s, v26.4h, v4.4h\n"
+    "smlal2 v6.4s, v26.8h, v4.8h\n"
+    "add x24, x24, x1\n"
+    "smlal v11.4s, v27.4h, v0.4h\n"
+    "smlal2 v21.4s, v27.8h, v0.8h\n"
+    "smlal v14.4s, v23.4h, v0.4h\n"
+    "smlal2 v10.4s, v23.8h, v0.8h\n"
+    "tbz x4, #2, 93f\n"
+    "ld1 { v25.s }[0], [x24], #0x4\n"
+    "tbz x4, #1, 92f\n"
+    "ld1 { v25.h }[2], [x24], #0x2\n"
+    "tbz x4, #0, 95f\n"
+    "ld1 { v25.b }[6], [x24]\n"
+    "b 95f\n"
+    "92:"  // Oddments: Load (5, 0): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 95f\n"
+    "ld1 { v25.b }[4], [x24]\n"
+    "b 95f\n"
+    "93:"  // Oddments: Load (5, 0): Bit 2: Unset
+    "tbz x4, #1, 94f\n"
+    "ld1 { v25.h }[0], [x24], #0x2\n"
+    "tbz x4, #0, 95f\n"
+    "ld1 { v25.b }[2], [x24]\n"
+    "b 95f\n"
+    "94:"  // Oddments: Load (5, 0): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 95f\n"
+    "ld1 { v25.b }[0], [x24]\n"
+    "95:"  // Oddments: Load (5, 0): Bit 2: End
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "ldr x23, [x0, #0xf8]\n"
+    "smlal v9.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "add x23, x23, x1\n"
+    "tbz x4, #2, 97f\n"
+    "ld1 { v24.s }[0], [x23], #0x4\n"
+    "tbz x4, #1, 96f\n"
+    "ld1 { v24.h }[2], [x23], #0x2\n"
+    "tbz x4, #0, 99f\n"
+    "ld1 { v24.b }[6], [x23]\n"
+    "b 99f\n"
+    "96:"  // Oddments: Load (5, 1): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 99f\n"
+    "ld1 { v24.b }[4], [x23]\n"
+    "b 99f\n"
+    "97:"  // Oddments: Load (5, 1): Bit 2: Unset
+    "tbz x4, #1, 98f\n"
+    "ld1 { v24.h }[0], [x23], #0x2\n"
+    "tbz x4, #0, 99f\n"
+    "ld1 { v24.b }[2], [x23]\n"
+    "b 99f\n"
+    "98:"  // Oddments: Load (5, 1): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 99f\n"
+    "ld1 { v24.b }[0], [x23]\n"
+    "99:"  // Oddments: Load (5, 1): Bit 2: End
+    "ldr d1, [x8, #0xa8]\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "usubl v1.8h, v1.8b, v15.8b\n"
+    "ldr x22, [x0, #0x100]\n"
+    "smlal v7.4s, v24.4h, v0.4h\n"
+    "smlal2 v6.4s, v24.8h, v0.8h\n"
+    "add x22, x22, x1\n"
+    "smlal v11.4s, v23.4h, v1.4h\n"
+    "smlal2 v21.4s, v23.8h, v1.8h\n"
+    "smlal v14.4s, v31.4h, v1.4h\n"
+    "smlal2 v10.4s, v31.8h, v1.8h\n"
+    "smlal v9.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "tbz x4, #2, 101f\n"
+    "ld1 { v27.s }[0], [x22], #0x4\n"
+    "tbz x4, #1, 100f\n"
+    "ld1 { v27.h }[2], [x22], #0x2\n"
+    "tbz x4, #0, 103f\n"
+    "ld1 { v27.b }[6], [x22]\n"
+    "b 103f\n"
+    "100:"  // Oddments: Load (5, 2): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 103f\n"
+    "ld1 { v27.b }[4], [x22]\n"
+    "b 103f\n"
+    "101:"  // Oddments: Load (5, 2): Bit 2: Unset
+    "tbz x4, #1, 102f\n"
+    "ld1 { v27.h }[0], [x22], #0x2\n"
+    "tbz x4, #0, 103f\n"
+    "ld1 { v27.b }[2], [x22]\n"
+    "b 103f\n"
+    "102:"  // Oddments: Load (5, 2): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 103f\n"
+    "ld1 { v27.b }[0], [x22]\n"
+    "103:"  // Oddments: Load (5, 2): Bit 2: End
+    "ldr d2, [x8, #0xb0]\n"
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "usubl v2.8h, v2.8b, v15.8b\n"
+    "ldr x7, [x0, #0x108]\n"
+    "smlal v7.4s, v27.4h, v1.4h\n"
+    "smlal2 v6.4s, v27.8h, v1.8h\n"
+    "add x7, x7, x1\n"
+    "smlal v11.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "smlal v14.4s, v30.4h, v2.4h\n"
+    "smlal2 v10.4s, v30.8h, v2.8h\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "tbz x4, #2, 105f\n"
+    "ld1 { v25.s }[0], [x7], #0x4\n"
+    "tbz x4, #1, 104f\n"
+    "ld1 { v25.h }[2], [x7], #0x2\n"
+    "tbz x4, #0, 107f\n"
+    "ld1 { v25.b }[6], [x7]\n"
+    "b 107f\n"
+    "104:"  // Oddments: Load (5, 3): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 107f\n"
+    "ld1 { v25.b }[4], [x7]\n"
+    "b 107f\n"
+    "105:"  // Oddments: Load (5, 3): Bit 2: Unset
+    "tbz x4, #1, 106f\n"
+    "ld1 { v25.h }[0], [x7], #0x2\n"
+    "tbz x4, #0, 107f\n"
+    "ld1 { v25.b }[2], [x7]\n"
+    "b 107f\n"
+    "106:"  // Oddments: Load (5, 3): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 107f\n"
+    "ld1 { v25.b }[0], [x7]\n"
+    "107:"  // Oddments: Load (5, 3): Bit 2: End
+    "ldr d3, [x8, #0xb8]\n"
+    "ushll v25.8h, v25.8b, #0x0\n"
+    "usubl v3.8h, v3.8b, v15.8b\n"
+    "ldr x20, [x0, #0x110]\n"
+    "smlal v7.4s, v25.4h, v2.4h\n"
+    "smlal2 v6.4s, v25.8h, v2.8h\n"
+    "add x20, x20, x1\n"
+    "smlal v11.4s, v30.4h, v3.4h\n"
+    "smlal2 v21.4s, v30.8h, v3.8h\n"
+    "smlal v14.4s, v28.4h, v3.4h\n"
+    "smlal2 v10.4s, v28.8h, v3.8h\n"
+    "smlal v9.4s, v25.4h, v3.4h\n"
+    "smlal2 v8.4s, v25.8h, v3.8h\n"
+    "tbz x4, #2, 109f\n"
+    "ld1 { v24.s }[0], [x20], #0x4\n"
+    "tbz x4, #1, 108f\n"
+    "ld1 { v24.h }[2], [x20], #0x2\n"
+    "tbz x4, #0, 111f\n"
+    "ld1 { v24.b }[6], [x20]\n"
+    "b 111f\n"
+    "108:"  // Oddments: Load (5, 4): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 111f\n"
+    "ld1 { v24.b }[4], [x20]\n"
+    "b 111f\n"
+    "109:"  // Oddments: Load (5, 4): Bit 2: Unset
+    "tbz x4, #1, 110f\n"
+    "ld1 { v24.h }[0], [x20], #0x2\n"
+    "tbz x4, #0, 111f\n"
+    "ld1 { v24.b }[2], [x20]\n"
+    "b 111f\n"
+    "110:"  // Oddments: Load (5, 4): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 111f\n"
+    "ld1 { v24.b }[0], [x20]\n"
+    "111:"  // Oddments: Load (5, 4): Bit 2: End
+    "ldr d4, [x8, #0xc0]\n"
+    "ushll v24.8h, v24.8b, #0x0\n"
+    "usubl v4.8h, v4.8b, v15.8b\n"
+    "ldr x19, [x0, #0x118]\n"
+    "smlal v7.4s, v24.4h, v3.4h\n"
+    "smlal2 v6.4s, v24.8h, v3.8h\n"
+    "add x19, x19, x1\n"
+    "smlal v11.4s, v28.4h, v4.4h\n"
+    "smlal2 v21.4s, v28.8h, v4.8h\n"
+    "smlal v14.4s, v26.4h, v4.4h\n"
+    "smlal2 v10.4s, v26.8h, v4.8h\n"
+    "smlal v9.4s, v24.4h, v4.4h\n"
+    "smlal2 v8.4s, v24.8h, v4.8h\n"
+    "tbz x4, #2, 113f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x4, #1, 112f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x4, #0, 115f\n"
+    "ld1 { v27.b }[6], [x19]\n"
+    "b 115f\n"
+    "112:"  // Oddments: Load (5, 5): Bit 2: Bit 1: Unset
+    "tbz x4, #0, 115f\n"
+    "ld1 { v27.b }[4], [x19]\n"
+    "b 115f\n"
+    "113:"  // Oddments: Load (5, 5): Bit 2: Unset
+    "tbz x4, #1, 114f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x4, #0, 115f\n"
+    "ld1 { v27.b }[2], [x19]\n"
+    "b 115f\n"
+    "114:"  // Oddments: Load (5, 5): Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 115f\n"
+    "ld1 { v27.b }[0], [x19]\n"
+    "115:"  // Oddments: Load (5, 5): Bit 2: End
+    "ushll v27.8h, v27.8b, #0x0\n"
+    "smlal v7.4s, v27.4h, v4.4h\n"
+    "smlal2 v6.4s, v27.8h, v4.8h\n"
+    "tbz x4, #2, 117f\n"
+    "ld1 { v17.4s }, [x5], #0x10\n"
+    "ld1 { v5.4s }, [x6], #0x10\n"
+    "tbz x4, #1, 116f\n"
+    "ld1 { v18.d }[0], [x5], #0x8\n"
+    "ld1 { v29.d }[0], [x6], #0x8\n"
+    "tbz x4, #0, 119f\n"
+    "ld1 { v18.s }[2], [x5]\n"
+    "ld1 { v29.s }[2], [x6]\n"
+    "b 119f\n"
+    "116:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
+    "tbz x4, #0, 119f\n"
+    "ld1 { v18.s }[0], [x5]\n"
+    "ld1 { v29.s }[0], [x6]\n"
+    "b 119f\n"
+    "117:"  // Oddments: Load requant params: Bit 2: Unset
+    "tbz x4, #1, 118f\n"
+    "ld1 { v17.d }[0], [x5], #0x8\n"
+    "ld1 { v5.d }[0], [x6], #0x8\n"
+    "tbz x4, #0, 119f\n"
+    "ld1 { v17.s }[2], [x5]\n"
+    "ld1 { v5.s }[2], [x6]\n"
+    "b 119f\n"
+    "118:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 119f\n"
+    "ld1 { v17.s }[0], [x5]\n"
+    "ld1 { v5.s }[0], [x6]\n"
+    "119:"  // Oddments: Load requant params: Bit 2: End
+    "sqdmulh v11.4s, v11.4s, v17.4s\n"
+    "sqdmulh v14.4s, v14.4s, v17.4s\n"
+    "add x21, x21, x2\n"
+    "add x15, x15, x2\n"
+    "sqdmulh v9.4s, v9.4s, v17.4s\n"
+    "sqdmulh v7.4s, v7.4s, v17.4s\n"
+    "add x17, x17, x2\n"
+    "add x16, x16, x2\n"
+    "and v23.16b, v11.16b, v5.16b\n"
+    "sqdmulh v21.4s, v21.4s, v18.4s\n"
+    "and v22.16b, v14.16b, v5.16b\n"
+    "sqdmulh v10.4s, v10.4s, v18.4s\n"
+    "and v17.16b, v9.16b, v5.16b\n"
+    "sqdmulh v8.4s, v8.4s, v18.4s\n"
+    "and v20.16b, v7.16b, v5.16b\n"
+    "sqdmulh v6.4s, v6.4s, v18.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v19.16b, v21.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v18.16b, v10.16b, v29.16b\n"
+    "sshr v17.4s, v17.4s, #0x1f\n"
+    "and v26.16b, v8.16b, v29.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v4.16b, v6.16b, v29.16b\n"
+    "sqadd v11.4s, v11.4s, v23.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v14.4s, v14.4s, v22.4s\n"
+    "sshr v18.4s, v18.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v17.4s\n"
+    "sshr v26.4s, v26.4s, #0x1f\n"
+    "sqadd v7.4s, v7.4s, v20.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "srshl v11.4s, v11.4s, v5.4s\n"
+    "sqadd v21.4s, v21.4s, v19.4s\n"
+    "srshl v14.4s, v14.4s, v5.4s\n"
+    "sqadd v10.4s, v10.4s, v18.4s\n"
+    "srshl v9.4s, v9.4s, v5.4s\n"
+    "sqadd v8.4s, v8.4s, v26.4s\n"
+    "srshl v7.4s, v7.4s, v5.4s\n"
+    "sqadd v6.4s, v6.4s, v4.4s\n"
+    "srshl v21.4s, v21.4s, v29.4s\n"
+    "sqxtn v11.4h, v11.4s\n"
+    "srshl v10.4s, v10.4s, v29.4s\n"
+    "sqxtn v14.4h, v14.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v6.4s, v6.4s, v29.4s\n"
+    "sqxtn v7.4h, v7.4s\n"
+    "sqxtn2 v11.8h, v21.4s\n"
+    "sqxtn2 v14.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v8.4s\n"
+    "sqxtn2 v7.8h, v6.4s\n"
+    "sqadd v11.8h, v11.8h, v16.8h\n"
+    "sqadd v14.8h, v14.8h, v16.8h\n"
+    "sqadd v9.8h, v9.8h, v16.8h\n"
+    "sqadd v7.8h, v7.8h, v16.8h\n"
+    "smax v11.8h, v11.8h, v12.8h\n"
+    "smax v14.8h, v14.8h, v12.8h\n"
+    "smax v9.8h, v9.8h, v12.8h\n"
+    "smax v7.8h, v7.8h, v12.8h\n"
+    "smin v11.8h, v11.8h, v13.8h\n"
+    "smin v14.8h, v14.8h, v13.8h\n"
+    "smin v9.8h, v9.8h, v13.8h\n"
+    "smin v7.8h, v7.8h, v13.8h\n"
+    "uzp1 v11.16b, v11.16b, v11.16b\n"
+    "uzp1 v14.16b, v14.16b, v14.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v7.16b, v7.16b, v7.16b\n"
+    "tbz x4, #2, 121f\n"
+    "st1 { v11.s }[0], [x21], #0x4\n"
+    "st1 { v14.s }[0], [x15], #0x4\n"
+    "st1 { v9.s }[0], [x17], #0x4\n"
+    "st1 { v7.s }[0], [x16], #0x4\n"
+    "tbz x4, #1, 120f\n"
+    "st1 { v11.h }[2], [x21], #0x2\n"
+    "st1 { v14.h }[2], [x15], #0x2\n"
+    "st1 { v9.h }[2], [x17], #0x2\n"
+    "st1 { v7.h }[2], [x16], #0x2\n"
+    "tbz x4, #0, 123f\n"
+    "st1 { v11.b }[6], [x21], #0x1\n"
+    "st1 { v14.b }[6], [x15], #0x1\n"
+    "st1 { v9.b }[6], [x17], #0x1\n"
+    "st1 { v7.b }[6], [x16], #0x1\n"
+    "b 123f\n"
+    "120:"  // Oddments: Bit 2: Bit 1: Unset
+    "tbz x4, #0, 123f\n"
+    "st1 { v11.b }[4], [x21], #0x1\n"
+    "st1 { v14.b }[4], [x15], #0x1\n"
+    "st1 { v9.b }[4], [x17], #0x1\n"
+    "st1 { v7.b }[4], [x16], #0x1\n"
+    "b 123f\n"
+    "121:"  // Oddments: Bit 2: Unset
+    "tbz x4, #1, 122f\n"
+    "st1 { v11.h }[0], [x21], #0x2\n"
+    "st1 { v14.h }[0], [x15], #0x2\n"
+    "st1 { v9.h }[0], [x17], #0x2\n"
+    "st1 { v7.h }[0], [x16], #0x2\n"
+    "tbz x4, #0, 123f\n"
+    "st1 { v11.b }[2], [x21], #0x1\n"
+    "st1 { v14.b }[2], [x15], #0x1\n"
+    "st1 { v9.b }[2], [x17], #0x1\n"
+    "st1 { v7.b }[2], [x16], #0x1\n"
+    "b 123f\n"
+    "122:"  // Oddments: Bit 2: Unset: Bit 1: Unset
+    "tbz x4, #0, 123f\n"
+    "st1 { v11.b }[0], [x21], #0x1\n"
+    "st1 { v14.b }[0], [x15], #0x1\n"
+    "st1 { v9.b }[0], [x17], #0x1\n"
+    "st1 { v7.b }[0], [x16], #0x1\n"
+    "123:"  // Oddments: Bit 2: End
+    "124:"  // End
+    :
+    : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
+    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+  );
+}
+
+}  // namespace depthwise
+}  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index 1bacb5f..281511a 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst
+class a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>  // Strategy class for the 3x3, stride-1 kernel; the template arguments carry the types of the former input/weight/return/bias typedefs (presumably in that order - confirm against DepthwiseDepthfirstStrategy).
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>;  // Shorthand for the base class, used for the constructor call and KernelType below.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}  // The CPUInfo argument is unnamed and unused. Parent args are presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols) - TODO confirm against the Parent constructor.
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }  // Reports VLType::None, the same value the former static vl_type constant held.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_3x3_mla::get_packed_size;
-
-  kern_type kernel = a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
-
-  a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;  // Default-initialised to the free function declared before this class; non-const public member.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of the kernel member.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth in units of vector length; why it is 2 is not shown here - confirm against Parent.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
index 8cbbfae..22f9574 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const int8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const uint8_t *const *inptrs_raw,
-      const int8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -91,505 +91,489 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
     "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x17, #0x0\n"
-    "ldr x16, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x24, x19, %[offsetof_Requantize32_a_offset]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x22, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x21, x19, %[offsetof_Requantize32_c_offset]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v22.16b }, [x24]\n"
+    "ld1r { v12.16b }, [x23]\n"
+    "lsr x16, x8, #0x3\n"
+    "ld1r { v14.8h }, [x21]\n"
+    "ld1r { v17.8h }, [x20]\n"
     "mov x15, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x14, %x[params], %[offsetof_Params_inptrs]\n"
+    "mov x14, #0x0\n"
+    "ld1r { v15.8h }, [x19]\n"
     "ldr x13, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x12, x8, #0x3\n"
+    "add x12, %x[params], %[offsetof_Params_inptrs]\n"
     "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x19, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v21.16b }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v17.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v13.4s }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v15.4s }, [x20]\n"
-    "ld1r { v14.4s }, [x19]\n"
-    "ldp x10, x9, [x21, #0x0]\n"
-    "ldp x28, x27, [x21, #0x10]\n"
-    "cbz x12, 3f\n"
-    "subs x12, x12, #0x1\n"
+    "ldp x10, x9, [x22, #0x0]\n"
+    "ldp x28, x27, [x22, #0x10]\n"
+    "cbz x16, 3f\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q11, [x19, #0x0]\n"
-    "mov v23.16b, v11.16b\n"
+    "ldr q13, [x19, #0x0]\n"
+    "subs x16, x16, #0x1\n"
+    "mov v19.16b, v13.16b\n"
     "ldr q26, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v12.16b, v11.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v24.16b, v11.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v9.16b, v26.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "mov v22.16b, v26.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "mov v10.16b, v26.16b\n"
-    "ldr d4, [x16, #0x20]\n"
-    "ssubl v0.8h, v0.8b, v17.8b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "ssubl v1.8h, v1.8b, v17.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ssubl v2.8h, v2.8b, v17.8b\n"
-    "ldr d7, [x16, #0x38]\n"
-    "ssubl v3.8h, v3.8b, v17.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "ssubl v4.8h, v4.8b, v17.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "ssubl v5.8h, v5.8b, v17.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "ssubl v6.8h, v6.8b, v17.8b\n"
-    "ssubl v7.8h, v7.8b, v17.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ssubl v8.8h, v8.8b, v17.8b\n"
-    "ldr d31, [x23, x17]\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "ldr d30, [x22, x17]\n"
-    "ldr d29, [x21, x17]\n"
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "ldr d28, [x20, x17]\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "ldr d27, [x19, x17]\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "usubl v27.8h, v27.8b, v21.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v11.16b, v26.16b\n"
+    "mov v18.16b, v13.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v24.16b, v26.16b\n"
+    "mov v9.16b, v13.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v23.16b, v26.16b\n"
+    "ssubl v0.8h, v0.8b, v12.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v1.8h, v1.8b, v12.8b\n"
+    "ssubl v2.8h, v2.8b, v12.8b\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "ssubl v3.8h, v3.8b, v12.8b\n"
+    "ssubl v4.8h, v4.8b, v12.8b\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "ssubl v5.8h, v5.8b, v12.8b\n"
+    "ssubl v6.8h, v6.8b, v12.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "ssubl v7.8h, v7.8b, v12.8b\n"
+    "ssubl v8.8h, v8.8b, v12.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "usubl v27.8h, v27.8b, v22.8b\n"
     "beq 2f\n"
     "1:"  // Loop
-    "smlal v11.4s, v31.4h, v4.4h\n"
-    "ldr x21, [x14, #0x28]\n"
-    "add x16, x16, #0x48\n"
+    "smlal v13.4s, v31.4h, v4.4h\n"
     "smlal2 v26.4s, v31.8h, v4.8h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "subs x12, x12, #0x1\n"
-    "smlal v23.4s, v31.4h, v3.4h\n"
-    "ldr x26, [x14, #0x38]\n"
-    "smlal2 v9.4s, v31.8h, v3.8h\n"
-    "ldr x25, [x14, #0x40]\n"
-    "smlal v12.4s, v31.4h, v1.4h\n"
-    "ldr x19, [x14, #0x48]\n"
-    "smlal2 v22.4s, v31.8h, v1.8h\n"
-    "ldr x24, [x14, #0x50]\n"
-    "smlal v24.4s, v31.4h, v0.4h\n"
-    "ldr x23, [x14, #0x58]\n"
-    "smlal2 v10.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x21, x17]\n"
-    "smlal v11.4s, v30.4h, v0.4h\n"
-    "ldr x22, [x14, #0x60]\n"
+    "ldr x21, [x12, #0x28]\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x12, #0x30]\n"
+    "ldr x25, [x12, #0x40]\n"
+    "smlal v13.4s, v30.4h, v0.4h\n"
     "smlal2 v26.4s, v30.8h, v0.8h\n"
-    "ldr d30, [x19, x17]\n"
-    "smlal v23.4s, v29.4h, v2.4h\n"
-    "ldr x21, [x14, #0x68]\n"
-    "smlal2 v9.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "smlal v11.4s, v28.4h, v5.4h\n"
-    "ldr x20, [x14, #0x70]\n"
+    "ldr x19, [x12, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x12, #0x50]\n"
+    "ldr x23, [x12, #0x58]\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v13.4s, v28.4h, v5.4h\n"
     "smlal2 v26.4s, v28.8h, v5.8h\n"
-    "ldr x19, [x14, #0x78]\n"
-    "smlal v23.4s, v28.4h, v4.4h\n"
-    "ldr q25, [x13, #0x0]\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "ldr q18, [x11, #0x0]\n"
-    "smlal v12.4s, v28.4h, v2.4h\n"
-    "ldr q16, [x13, #0x10]\n"
-    "add x13, x13, #0x20\n"
-    "smlal2 v22.4s, v28.8h, v2.8h\n"
-    "ldr q20, [x11, #0x10]\n"
-    "add x11, x11, #0x20\n"
-    "smlal v24.4s, v28.4h, v1.4h\n"
-    "smlal2 v10.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x26, x17]\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "smlal v11.4s, v27.4h, v7.4h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x12, #0x68]\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x12, #0x78]\n"
+    "ldr q21, [x13, #0x0]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v13.4s, v27.4h, v7.4h\n"
     "smlal2 v26.4s, v27.8h, v7.8h\n"
-    "smlal v12.4s, v31.4h, v6.4h\n"
-    "smlal2 v22.4s, v31.8h, v6.8h\n"
-    "ldr d31, [x25, x17]\n"
-    "smlal v23.4s, v27.4h, v6.4h\n"
-    "smlal2 v9.4s, v27.8h, v6.8h\n"
-    "smlal v12.4s, v27.4h, v4.4h\n"
-    "smlal2 v22.4s, v27.8h, v4.8h\n"
-    "smlal v24.4s, v27.4h, v3.4h\n"
-    "smlal2 v10.4s, v27.8h, v3.8h\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "smlal v24.4s, v29.4h, v8.4h\n"
-    "smlal2 v10.4s, v29.8h, v8.8h\n"
-    "ldr d29, [x24, x17]\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
+    "ldr q25, [x11, #0x0]\n"
+    "ldr q10, [x13, #0x10]\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "ldr q16, [x11, #0x10]\n"
+    "add x17, x17, #0x48\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
+    "subs x16, x16, #0x1\n"
+    "add x13, x13, #0x20\n"
+    "smlal v13.4s, v28.4h, v1.4h\n"
     "smlal2 v26.4s, v28.8h, v1.8h\n"
-    "smlal v23.4s, v28.4h, v0.4h\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "ldr d28, [x23, x17]\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
+    "add x11, x11, #0x20\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v13.4s, v31.4h, v2.4h\n"
     "smlal2 v26.4s, v31.8h, v2.8h\n"
-    "smlal v23.4s, v31.4h, v1.4h\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "ldr d31, [x22, x17]\n"
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "smlal v11.4s, v30.4h, v8.4h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "smlal v13.4s, v30.4h, v8.4h\n"
     "smlal2 v26.4s, v30.8h, v8.8h\n"
-    "smlal v23.4s, v30.4h, v7.4h\n"
-    "smlal2 v9.4s, v30.8h, v7.8h\n"
-    "smlal v12.4s, v30.4h, v5.4h\n"
-    "smlal2 v22.4s, v30.8h, v5.8h\n"
-    "smlal v24.4s, v30.4h, v4.4h\n"
-    "smlal2 v10.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x17]\n"
-    "smlal v11.4s, v29.4h, v3.4h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
+    "smlal v13.4s, v29.4h, v3.4h\n"
     "smlal2 v26.4s, v29.8h, v3.8h\n"
-    "smlal v12.4s, v29.4h, v0.4h\n"
-    "smlal2 v22.4s, v29.8h, v0.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "smlal v23.4s, v28.4h, v5.4h\n"
-    "smlal2 v9.4s, v28.8h, v5.8h\n"
-    "smlal v24.4s, v28.4h, v2.4h\n"
-    "smlal2 v10.4s, v28.8h, v2.8h\n"
-    "ldr d28, [x19, x17]\n"
-    "add x17, x17, #0x8\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "smlal v11.4s, v31.4h, v6.4h\n"
-    "smlal2 v26.4s, v31.8h, v6.8h\n"
-    "smlal v12.4s, v31.4h, v3.4h\n"
-    "smlal2 v22.4s, v31.8h, v3.8h\n"
-    "smlal v23.4s, v30.4h, v8.4h\n"
-    "smlal2 v9.4s, v30.8h, v8.8h\n"
-    "smlal v24.4s, v30.4h, v5.4h\n"
-    "smlal2 v10.4s, v30.8h, v5.8h\n"
-    "smlal v12.4s, v29.4h, v7.4h\n"
-    "smlal2 v22.4s, v29.8h, v7.8h\n"
-    "smlal v24.4s, v29.4h, v6.4h\n"
-    "smlal2 v10.4s, v29.8h, v6.8h\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "sqrdmulh v11.4s, v11.4s, v25.4s\n"
-    "sqrdmulh v26.4s, v26.4s, v16.4s\n"
-    "smlal v12.4s, v28.4h, v8.4h\n"
-    "smlal2 v22.4s, v28.8h, v8.8h\n"
-    "smlal v24.4s, v28.4h, v7.4h\n"
-    "smlal2 v10.4s, v28.8h, v7.8h\n"
-    "and v19.16b, v11.16b, v18.16b\n"
-    "and v5.16b, v26.16b, v20.16b\n"
-    "sqrdmulh v23.4s, v23.4s, v25.4s\n"
-    "sshr v19.4s, v19.4s, #0x1f\n"
-    "sshr v5.4s, v5.4s, #0x1f\n"
-    "sqrdmulh v9.4s, v9.4s, v16.4s\n"
-    "sqadd v11.4s, v11.4s, v19.4s\n"
-    "sqadd v26.4s, v26.4s, v5.4s\n"
-    "and v28.16b, v23.16b, v18.16b\n"
-    "and v8.16b, v9.16b, v20.16b\n"
-    "srshl v11.4s, v11.4s, v18.4s\n"
-    "srshl v26.4s, v26.4s, v20.4s\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "add v11.4s, v11.4s, v13.4s\n"
-    "add v26.4s, v26.4s, v13.4s\n"
-    "sqadd v23.4s, v23.4s, v28.4s\n"
-    "smin v11.4s, v11.4s, v14.4s\n"
-    "smin v26.4s, v26.4s, v14.4s\n"
-    "sqadd v9.4s, v9.4s, v8.4s\n"
-    "smax v11.4s, v11.4s, v15.4s\n"
-    "smax v26.4s, v26.4s, v15.4s\n"
-    "srshl v23.4s, v23.4s, v18.4s\n"
-    "srshl v9.4s, v9.4s, v20.4s\n"
-    "uzp1 v11.16b, v11.16b, v26.16b\n"
-    "sqrdmulh v12.4s, v12.4s, v25.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x10, x15]\n"
-    "add v23.4s, v23.4s, v13.4s\n"
-    "add v9.4s, v9.4s, v13.4s\n"
-    "and v1.16b, v12.16b, v18.16b\n"
-    "sqrdmulh v22.4s, v22.4s, v16.4s\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
-    "smin v9.4s, v9.4s, v14.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "smax v23.4s, v23.4s, v15.4s\n"
-    "smax v9.4s, v9.4s, v15.4s\n"
-    "sqadd v12.4s, v12.4s, v1.4s\n"
-    "and v0.16b, v22.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v9.16b\n"
-    "sqrdmulh v24.4s, v24.4s, v25.4s\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
-    "str d23, [x9, x15]\n"
-    "srshl v12.4s, v12.4s, v18.4s\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "and v26.16b, v24.16b, v18.16b\n"
-    "sqrdmulh v10.4s, v10.4s, v16.4s\n"
-    "sqadd v22.4s, v22.4s, v0.4s\n"
-    "add v12.4s, v12.4s, v13.4s\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "and v16.16b, v10.16b, v20.16b\n"
-    "smin v12.4s, v12.4s, v14.4s\n"
-    "srshl v22.4s, v22.4s, v20.4s\n"
-    "sqadd v24.4s, v24.4s, v26.4s\n"
-    "smax v12.4s, v12.4s, v15.4s\n"
-    "sshr v16.4s, v16.4s, #0x1f\n"
-    "add v22.4s, v22.4s, v13.4s\n"
-    "srshl v24.4s, v24.4s, v18.4s\n"
-    "sqadd v10.4s, v10.4s, v16.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "add v24.4s, v24.4s, v13.4s\n"
-    "smax v22.4s, v22.4s, v15.4s\n"
-    "srshl v10.4s, v10.4s, v20.4s\n"
-    "smin v24.4s, v24.4s, v14.4s\n"
-    "uzp1 v12.16b, v12.16b, v22.16b\n"
-    "add v10.4s, v10.4s, v13.4s\n"
-    "uzp1 v12.16b, v12.16b, v12.16b\n"
-    "str d12, [x28, x15]\n"
-    "smax v24.4s, v24.4s, v15.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "smax v10.4s, v10.4s, v15.4s\n"
-    "uzp1 v24.16b, v24.16b, v10.16b\n"
-    "uzp1 v24.16b, v24.16b, v24.16b\n"
-    "str d24, [x27, x15]\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
     "add x15, x15, #0x8\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "str d13, [x10, x14]\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "str d19, [x9, x14]\n"
+    "str d18, [x28, x14]\n"
+    "str d9, [x27, x14]\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q11, [x19, #0x0]\n"
-    "mov v23.16b, v11.16b\n"
+    "ldr q13, [x19, #0x0]\n"
+    "add x14, x14, #0x8\n"
     "ldr q26, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v12.16b, v11.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v24.16b, v11.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v9.16b, v26.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "mov v22.16b, v26.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "mov v10.16b, v26.16b\n"
-    "ldr d4, [x16, #0x20]\n"
-    "ssubl v0.8h, v0.8b, v17.8b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "ssubl v1.8h, v1.8b, v17.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ssubl v2.8h, v2.8b, v17.8b\n"
-    "ldr d7, [x16, #0x38]\n"
-    "ssubl v3.8h, v3.8b, v17.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "ssubl v4.8h, v4.8b, v17.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "ssubl v5.8h, v5.8b, v17.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "ssubl v6.8h, v6.8b, v17.8b\n"
-    "ssubl v7.8h, v7.8b, v17.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ssubl v8.8h, v8.8b, v17.8b\n"
-    "ldr d31, [x23, x17]\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "ldr d30, [x22, x17]\n"
-    "ldr d29, [x21, x17]\n"
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "ldr d28, [x20, x17]\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "ldr d27, [x19, x17]\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "usubl v27.8h, v27.8b, v21.8b\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v19.16b, v13.16b\n"
+    "mov v11.16b, v26.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v18.16b, v13.16b\n"
+    "mov v24.16b, v26.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v9.16b, v13.16b\n"
+    "mov v23.16b, v26.16b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v0.8h, v0.8b, v12.8b\n"
+    "ssubl v1.8h, v1.8b, v12.8b\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "ssubl v2.8h, v2.8b, v12.8b\n"
+    "ssubl v3.8h, v3.8b, v12.8b\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ldr d31, [x23, x15]\n"
+    "ssubl v4.8h, v4.8b, v12.8b\n"
+    "ssubl v5.8h, v5.8b, v12.8b\n"
+    "ldr d30, [x22, x15]\n"
+    "ldr d29, [x21, x15]\n"
+    "ssubl v6.8h, v6.8b, v12.8b\n"
+    "ssubl v7.8h, v7.8b, v12.8b\n"
+    "ldr d28, [x20, x15]\n"
+    "ldr d27, [x19, x15]\n"
+    "ssubl v8.8h, v8.8b, v12.8b\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "usubl v27.8h, v27.8b, v22.8b\n"
     "bgt 1b\n"
     "2:"  // Tail
-    "smlal v11.4s, v31.4h, v4.4h\n"
-    "ldr x21, [x14, #0x28]\n"
-    "tst x8, #0x7\n"
+    "smlal v13.4s, v31.4h, v4.4h\n"
     "smlal2 v26.4s, v31.8h, v4.8h\n"
-    "ldr x20, [x14, #0x30]\n"
-    "smlal v23.4s, v31.4h, v3.4h\n"
-    "ldr x26, [x14, #0x38]\n"
-    "smlal2 v9.4s, v31.8h, v3.8h\n"
-    "ldr x25, [x14, #0x40]\n"
-    "smlal v12.4s, v31.4h, v1.4h\n"
-    "ldr x19, [x14, #0x48]\n"
-    "smlal2 v22.4s, v31.8h, v1.8h\n"
-    "ldr x24, [x14, #0x50]\n"
-    "smlal v24.4s, v31.4h, v0.4h\n"
-    "ldr x23, [x14, #0x58]\n"
-    "smlal2 v10.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x21, x17]\n"
-    "smlal v11.4s, v30.4h, v0.4h\n"
-    "ldr x22, [x14, #0x60]\n"
+    "ldr x21, [x12, #0x28]\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "ldr x20, [x12, #0x30]\n"
+    "ldr x25, [x12, #0x40]\n"
+    "smlal v13.4s, v30.4h, v0.4h\n"
     "smlal2 v26.4s, v30.8h, v0.8h\n"
-    "ldr d30, [x19, x17]\n"
-    "smlal v23.4s, v29.4h, v2.4h\n"
-    "ldr x21, [x14, #0x68]\n"
-    "smlal2 v9.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "smlal v11.4s, v28.4h, v5.4h\n"
-    "ldr x20, [x14, #0x70]\n"
+    "ldr x19, [x12, #0x48]\n"
+    "ldr d30, [x19, x15]\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "ldr x24, [x12, #0x50]\n"
+    "ldr x23, [x12, #0x58]\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x21, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v13.4s, v28.4h, v5.4h\n"
     "smlal2 v26.4s, v28.8h, v5.8h\n"
-    "ldr x19, [x14, #0x78]\n"
-    "smlal v23.4s, v28.4h, v4.4h\n"
-    "ldr q25, [x13, #0x0]\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "ldr q18, [x11, #0x0]\n"
-    "smlal v12.4s, v28.4h, v2.4h\n"
-    "ldr q16, [x13, #0x10]\n"
-    "add x13, x13, #0x20\n"
-    "smlal2 v22.4s, v28.8h, v2.8h\n"
-    "ldr q20, [x11, #0x10]\n"
-    "add x11, x11, #0x20\n"
-    "smlal v24.4s, v28.4h, v1.4h\n"
-    "smlal2 v10.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x26, x17]\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "smlal v11.4s, v27.4h, v7.4h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "ldr x21, [x12, #0x68]\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "ldr x19, [x12, #0x78]\n"
+    "ldr q21, [x13, #0x0]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x26, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v13.4s, v27.4h, v7.4h\n"
     "smlal2 v26.4s, v27.8h, v7.8h\n"
-    "smlal v12.4s, v31.4h, v6.4h\n"
-    "smlal2 v22.4s, v31.8h, v6.8h\n"
-    "ldr d31, [x25, x17]\n"
-    "smlal v23.4s, v27.4h, v6.4h\n"
-    "smlal2 v9.4s, v27.8h, v6.8h\n"
-    "smlal v12.4s, v27.4h, v4.4h\n"
-    "smlal2 v22.4s, v27.8h, v4.8h\n"
-    "smlal v24.4s, v27.4h, v3.4h\n"
-    "smlal2 v10.4s, v27.8h, v3.8h\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "smlal v24.4s, v29.4h, v8.4h\n"
-    "smlal2 v10.4s, v29.8h, v8.8h\n"
-    "ldr d29, [x24, x17]\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
+    "ldr q25, [x11, #0x0]\n"
+    "ldr q10, [x13, #0x10]\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "ldr q16, [x11, #0x10]\n"
+    "tst x8, #0x7\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr d31, [x25, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
+    "add x13, x13, #0x20\n"
+    "add x11, x11, #0x20\n"
+    "smlal v13.4s, v28.4h, v1.4h\n"
     "smlal2 v26.4s, v28.8h, v1.8h\n"
-    "smlal v23.4s, v28.4h, v0.4h\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
-    "ldr d28, [x23, x17]\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v13.4s, v31.4h, v2.4h\n"
     "smlal2 v26.4s, v31.8h, v2.8h\n"
-    "smlal v23.4s, v31.4h, v1.4h\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
-    "ldr d31, [x22, x17]\n"
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "smlal v11.4s, v30.4h, v8.4h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "ldr d31, [x22, x15]\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "smlal v13.4s, v30.4h, v8.4h\n"
     "smlal2 v26.4s, v30.8h, v8.8h\n"
-    "smlal v23.4s, v30.4h, v7.4h\n"
-    "smlal2 v9.4s, v30.8h, v7.8h\n"
-    "smlal v12.4s, v30.4h, v5.4h\n"
-    "smlal2 v22.4s, v30.8h, v5.8h\n"
-    "smlal v24.4s, v30.4h, v4.4h\n"
-    "smlal2 v10.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x17]\n"
-    "smlal v11.4s, v29.4h, v3.4h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x21, x15]\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
+    "smlal v13.4s, v29.4h, v3.4h\n"
     "smlal2 v26.4s, v29.8h, v3.8h\n"
-    "smlal v12.4s, v29.4h, v0.4h\n"
-    "smlal2 v22.4s, v29.8h, v0.8h\n"
-    "ldr d29, [x20, x17]\n"
-    "smlal v23.4s, v28.4h, v5.4h\n"
-    "smlal2 v9.4s, v28.8h, v5.8h\n"
-    "smlal v24.4s, v28.4h, v2.4h\n"
-    "smlal2 v10.4s, v28.8h, v2.8h\n"
-    "ldr d28, [x19, x17]\n"
-    "add x17, x17, #0x8\n"
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "smlal v11.4s, v31.4h, v6.4h\n"
-    "smlal2 v26.4s, v31.8h, v6.8h\n"
-    "smlal v12.4s, v31.4h, v3.4h\n"
-    "smlal2 v22.4s, v31.8h, v3.8h\n"
-    "smlal v23.4s, v30.4h, v8.4h\n"
-    "smlal2 v9.4s, v30.8h, v8.8h\n"
-    "smlal v24.4s, v30.4h, v5.4h\n"
-    "smlal2 v10.4s, v30.8h, v5.8h\n"
-    "smlal v12.4s, v29.4h, v7.4h\n"
-    "smlal2 v22.4s, v29.8h, v7.8h\n"
-    "smlal v24.4s, v29.4h, v6.4h\n"
-    "smlal2 v10.4s, v29.8h, v6.8h\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "sqrdmulh v11.4s, v11.4s, v25.4s\n"
-    "sqrdmulh v26.4s, v26.4s, v16.4s\n"
-    "smlal v12.4s, v28.4h, v8.4h\n"
-    "smlal2 v22.4s, v28.8h, v8.8h\n"
-    "smlal v24.4s, v28.4h, v7.4h\n"
-    "smlal2 v10.4s, v28.8h, v7.8h\n"
-    "and v19.16b, v11.16b, v18.16b\n"
-    "and v5.16b, v26.16b, v20.16b\n"
-    "sqrdmulh v23.4s, v23.4s, v25.4s\n"
-    "sshr v19.4s, v19.4s, #0x1f\n"
-    "sshr v5.4s, v5.4s, #0x1f\n"
-    "sqrdmulh v9.4s, v9.4s, v16.4s\n"
-    "sqadd v11.4s, v11.4s, v19.4s\n"
-    "sqadd v26.4s, v26.4s, v5.4s\n"
-    "and v28.16b, v23.16b, v18.16b\n"
-    "and v8.16b, v9.16b, v20.16b\n"
-    "srshl v11.4s, v11.4s, v18.4s\n"
-    "srshl v26.4s, v26.4s, v20.4s\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "add v11.4s, v11.4s, v13.4s\n"
-    "add v26.4s, v26.4s, v13.4s\n"
-    "sqadd v23.4s, v23.4s, v28.4s\n"
-    "smin v11.4s, v11.4s, v14.4s\n"
-    "smin v26.4s, v26.4s, v14.4s\n"
-    "sqadd v9.4s, v9.4s, v8.4s\n"
-    "smax v11.4s, v11.4s, v15.4s\n"
-    "smax v26.4s, v26.4s, v15.4s\n"
-    "srshl v23.4s, v23.4s, v18.4s\n"
-    "srshl v9.4s, v9.4s, v20.4s\n"
-    "uzp1 v11.16b, v11.16b, v26.16b\n"
-    "sqrdmulh v12.4s, v12.4s, v25.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x10, x15]\n"
-    "add v23.4s, v23.4s, v13.4s\n"
-    "add v9.4s, v9.4s, v13.4s\n"
-    "and v1.16b, v12.16b, v18.16b\n"
-    "sqrdmulh v22.4s, v22.4s, v16.4s\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
-    "smin v9.4s, v9.4s, v14.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "smax v23.4s, v23.4s, v15.4s\n"
-    "smax v9.4s, v9.4s, v15.4s\n"
-    "sqadd v12.4s, v12.4s, v1.4s\n"
-    "and v0.16b, v22.16b, v20.16b\n"
-    "uzp1 v23.16b, v23.16b, v9.16b\n"
-    "sqrdmulh v24.4s, v24.4s, v25.4s\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
-    "str d23, [x9, x15]\n"
-    "srshl v12.4s, v12.4s, v18.4s\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "and v26.16b, v24.16b, v18.16b\n"
-    "sqrdmulh v10.4s, v10.4s, v16.4s\n"
-    "sqadd v22.4s, v22.4s, v0.4s\n"
-    "add v12.4s, v12.4s, v13.4s\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "and v16.16b, v10.16b, v20.16b\n"
-    "smin v12.4s, v12.4s, v14.4s\n"
-    "srshl v22.4s, v22.4s, v20.4s\n"
-    "sqadd v24.4s, v24.4s, v26.4s\n"
-    "smax v12.4s, v12.4s, v15.4s\n"
-    "sshr v16.4s, v16.4s, #0x1f\n"
-    "add v22.4s, v22.4s, v13.4s\n"
-    "srshl v24.4s, v24.4s, v18.4s\n"
-    "sqadd v10.4s, v10.4s, v16.4s\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "add v24.4s, v24.4s, v13.4s\n"
-    "smax v22.4s, v22.4s, v15.4s\n"
-    "srshl v10.4s, v10.4s, v20.4s\n"
-    "smin v24.4s, v24.4s, v14.4s\n"
-    "uzp1 v12.16b, v12.16b, v22.16b\n"
-    "add v10.4s, v10.4s, v13.4s\n"
-    "uzp1 v12.16b, v12.16b, v12.16b\n"
-    "str d12, [x28, x15]\n"
-    "smax v24.4s, v24.4s, v15.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "smax v10.4s, v10.4s, v15.4s\n"
-    "uzp1 v24.16b, v24.16b, v10.16b\n"
-    "uzp1 v24.16b, v24.16b, v24.16b\n"
-    "str d24, [x27, x15]\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "ldr d29, [x20, x15]\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x19, x15]\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
     "add x15, x15, #0x8\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "smlal2 v26.4s, v31.8h, v6.8h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "str d13, [x10, x14]\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "str d19, [x9, x14]\n"
+    "str d18, [x28, x14]\n"
+    "str d9, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
     "beq 64f\n"
-    "add x16, x16, #0x48\n"
+    "add x17, x17, #0x48\n"
     "3:"  // Oddments
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
     "tbz x8, #2, 5f\n"
-    "ld1 { v11.4s }, [x19], #0x10\n"
+    "ld1 { v13.4s }, [x19], #0x10\n"
     "tbz x8, #1, 4f\n"
     "ld1 { v26.d }[0], [x19], #0x8\n"
     "tbz x8, #0, 7f\n"
@@ -601,46 +585,46 @@
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
     "tbz x8, #1, 6f\n"
-    "ld1 { v11.d }[0], [x19], #0x8\n"
+    "ld1 { v13.d }[0], [x19], #0x8\n"
     "tbz x8, #0, 7f\n"
-    "ld1 { v11.s }[2], [x19]\n"
+    "ld1 { v13.s }[2], [x19]\n"
     "b 7f\n"
     "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
     "tbz x8, #0, 7f\n"
-    "ld1 { v11.s }[0], [x19]\n"
+    "ld1 { v13.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v23.16b, v11.16b\n"
-    "ldr d0, [x16, #0x0]\n"
-    "mov v9.16b, v26.16b\n"
-    "ldr d1, [x16, #0x8]\n"
-    "mov v12.16b, v11.16b\n"
-    "ldr d2, [x16, #0x10]\n"
-    "mov v22.16b, v26.16b\n"
-    "ldr d3, [x16, #0x18]\n"
-    "mov v24.16b, v11.16b\n"
-    "ldr d4, [x16, #0x20]\n"
-    "mov v10.16b, v26.16b\n"
-    "ldr d5, [x16, #0x28]\n"
-    "ssubl v0.8h, v0.8b, v17.8b\n"
-    "ldr d6, [x16, #0x30]\n"
-    "ssubl v1.8h, v1.8b, v17.8b\n"
-    "ldr d7, [x16, #0x38]\n"
-    "ssubl v2.8h, v2.8b, v17.8b\n"
-    "ldr d8, [x16, #0x40]\n"
-    "ssubl v3.8h, v3.8b, v17.8b\n"
-    "ldp x23, x22, [x14, #0x0]\n"
-    "add x23, x23, x17\n"
-    "ssubl v4.8h, v4.8b, v17.8b\n"
-    "ldp x21, x20, [x14, #0x10]\n"
-    "ssubl v5.8h, v5.8b, v17.8b\n"
-    "ldr x19, [x14, #0x20]\n"
-    "ssubl v6.8h, v6.8b, v17.8b\n"
-    "add x22, x22, x17\n"
-    "ssubl v7.8h, v7.8b, v17.8b\n"
-    "add x21, x21, x17\n"
-    "ssubl v8.8h, v8.8b, v17.8b\n"
-    "add x20, x20, x17\n"
-    "add x19, x19, x17\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v19.16b, v13.16b\n"
+    "mov v11.16b, v26.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v18.16b, v13.16b\n"
+    "mov v24.16b, v26.16b\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v9.16b, v13.16b\n"
+    "mov v23.16b, v26.16b\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ssubl v0.8h, v0.8b, v12.8b\n"
+    "ssubl v1.8h, v1.8b, v12.8b\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x23, x22, [x12, #0x0]\n"
+    "ssubl v2.8h, v2.8b, v12.8b\n"
+    "ssubl v3.8h, v3.8b, v12.8b\n"
+    "ldp x21, x20, [x12, #0x10]\n"
+    "ldr x19, [x12, #0x20]\n"
+    "ssubl v4.8h, v4.8b, v12.8b\n"
+    "ssubl v5.8h, v5.8b, v12.8b\n"
+    "ssubl v6.8h, v6.8b, v12.8b\n"
+    "ssubl v7.8h, v7.8b, v12.8b\n"
+    "ssubl v8.8h, v8.8b, v12.8b\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 9f\n"
     "ld1 { v31.s }[0], [x23], #0x4\n"
     "ld1 { v30.s }[0], [x22], #0x4\n"
@@ -690,33 +674,33 @@
     "ld1 { v28.b }[0], [x20]\n"
     "ld1 { v27.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "ldr x21, [x14, #0x28]\n"
-    "add x21, x21, x17\n"
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "usubl v27.8h, v27.8b, v21.8b\n"
-    "smlal v11.4s, v31.4h, v4.4h\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v13.4s, v31.4h, v4.4h\n"
     "smlal2 v26.4s, v31.8h, v4.8h\n"
-    "smlal v23.4s, v31.4h, v3.4h\n"
-    "smlal2 v9.4s, v31.8h, v3.8h\n"
-    "smlal v12.4s, v31.4h, v1.4h\n"
-    "smlal2 v22.4s, v31.8h, v1.8h\n"
-    "smlal v24.4s, v31.4h, v0.4h\n"
-    "smlal2 v10.4s, v31.8h, v0.8h\n"
-    "smlal v11.4s, v30.4h, v0.4h\n"
+    "ldr x21, [x12, #0x28]\n"
+    "smlal v19.4s, v31.4h, v3.4h\n"
+    "smlal2 v11.4s, v31.8h, v3.8h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "add x21, x21, x15\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v1.4h\n"
+    "smlal2 v24.4s, v31.8h, v1.8h\n"
+    "smlal v9.4s, v31.4h, v0.4h\n"
+    "smlal2 v23.4s, v31.8h, v0.8h\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v13.4s, v30.4h, v0.4h\n"
     "smlal2 v26.4s, v30.8h, v0.8h\n"
-    "smlal v23.4s, v29.4h, v2.4h\n"
-    "smlal2 v9.4s, v29.8h, v2.8h\n"
-    "smlal v11.4s, v28.4h, v5.4h\n"
+    "usubl v27.8h, v27.8b, v22.8b\n"
+    "smlal v19.4s, v29.4h, v2.4h\n"
+    "smlal2 v11.4s, v29.8h, v2.8h\n"
+    "smlal v13.4s, v28.4h, v5.4h\n"
     "smlal2 v26.4s, v28.8h, v5.8h\n"
-    "smlal v23.4s, v28.4h, v4.4h\n"
-    "smlal2 v9.4s, v28.8h, v4.8h\n"
-    "smlal v12.4s, v28.4h, v2.4h\n"
-    "smlal2 v22.4s, v28.8h, v2.8h\n"
-    "smlal v24.4s, v28.4h, v1.4h\n"
-    "smlal2 v10.4s, v28.8h, v1.8h\n"
+    "smlal v19.4s, v28.4h, v4.4h\n"
+    "smlal2 v11.4s, v28.8h, v4.8h\n"
+    "smlal v18.4s, v28.4h, v2.4h\n"
+    "smlal2 v24.4s, v28.8h, v2.8h\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v23.4s, v28.8h, v1.8h\n"
     "tbz x8, #2, 13f\n"
     "ld1 { v31.s }[0], [x21], #0x4\n"
     "tbz x8, #1, 12f\n"
@@ -738,19 +722,19 @@
     "tbz x8, #0, 15f\n"
     "ld1 { v31.b }[0], [x21]\n"
     "15:"  // Oddments: Load (3, 0): Bit 2: End
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "ldr x20, [x14, #0x30]\n"
-    "smlal v11.4s, v27.4h, v7.4h\n"
-    "add x20, x20, x17\n"
-    "smlal v12.4s, v31.4h, v6.4h\n"
-    "smlal2 v22.4s, v31.8h, v6.8h\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "smlal v18.4s, v31.4h, v6.4h\n"
+    "smlal2 v24.4s, v31.8h, v6.8h\n"
+    "ldr x20, [x12, #0x30]\n"
+    "smlal v13.4s, v27.4h, v7.4h\n"
     "smlal2 v26.4s, v27.8h, v7.8h\n"
-    "smlal v23.4s, v27.4h, v6.4h\n"
-    "smlal2 v9.4s, v27.8h, v6.8h\n"
-    "smlal v12.4s, v27.4h, v4.4h\n"
-    "smlal2 v22.4s, v27.8h, v4.8h\n"
-    "smlal v24.4s, v27.4h, v3.4h\n"
-    "smlal2 v10.4s, v27.8h, v3.8h\n"
+    "add x20, x20, x15\n"
+    "smlal v19.4s, v27.4h, v6.4h\n"
+    "smlal2 v11.4s, v27.8h, v6.8h\n"
+    "smlal v18.4s, v27.4h, v4.4h\n"
+    "smlal2 v24.4s, v27.8h, v4.8h\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v23.4s, v27.8h, v3.8h\n"
     "tbz x8, #2, 17f\n"
     "ld1 { v29.s }[0], [x20], #0x4\n"
     "tbz x8, #1, 16f\n"
@@ -772,11 +756,11 @@
     "tbz x8, #0, 19f\n"
     "ld1 { v29.b }[0], [x20]\n"
     "19:"  // Oddments: Load (3, 3): Bit 2: End
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "ldr x26, [x14, #0x38]\n"
-    "smlal v24.4s, v29.4h, v8.4h\n"
-    "add x26, x26, x17\n"
-    "smlal2 v10.4s, v29.8h, v8.8h\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x26, [x12, #0x38]\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v23.4s, v29.8h, v8.8h\n"
+    "add x26, x26, x15\n"
     "tbz x8, #2, 21f\n"
     "ld1 { v28.s }[0], [x26], #0x4\n"
     "tbz x8, #1, 20f\n"
@@ -798,13 +782,13 @@
     "tbz x8, #0, 23f\n"
     "ld1 { v28.b }[0], [x26]\n"
     "23:"  // Oddments: Load (0, 1): Bit 2: End
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "ldr x25, [x14, #0x40]\n"
-    "smlal v11.4s, v28.4h, v1.4h\n"
-    "add x25, x25, x17\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "ldr x25, [x12, #0x40]\n"
+    "smlal v13.4s, v28.4h, v1.4h\n"
     "smlal2 v26.4s, v28.8h, v1.8h\n"
-    "smlal v23.4s, v28.4h, v0.4h\n"
-    "smlal2 v9.4s, v28.8h, v0.8h\n"
+    "smlal v19.4s, v28.4h, v0.4h\n"
+    "smlal2 v11.4s, v28.8h, v0.8h\n"
+    "add x25, x25, x15\n"
     "tbz x8, #2, 25f\n"
     "ld1 { v31.s }[0], [x25], #0x4\n"
     "tbz x8, #1, 24f\n"
@@ -826,13 +810,13 @@
     "tbz x8, #0, 27f\n"
     "ld1 { v31.b }[0], [x25]\n"
     "27:"  // Oddments: Load (0, 2): Bit 2: End
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "ldr x19, [x14, #0x48]\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "add x19, x19, x17\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "ldr x19, [x12, #0x48]\n"
+    "smlal v13.4s, v31.4h, v2.4h\n"
     "smlal2 v26.4s, v31.8h, v2.8h\n"
-    "smlal v23.4s, v31.4h, v1.4h\n"
-    "smlal2 v9.4s, v31.8h, v1.8h\n"
+    "smlal v19.4s, v31.4h, v1.4h\n"
+    "smlal2 v11.4s, v31.8h, v1.8h\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 29f\n"
     "ld1 { v30.s }[0], [x19], #0x4\n"
     "tbz x8, #1, 28f\n"
@@ -854,17 +838,17 @@
     "tbz x8, #0, 31f\n"
     "ld1 { v30.b }[0], [x19]\n"
     "31:"  // Oddments: Load (2, 2): Bit 2: End
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "ldr x24, [x14, #0x50]\n"
-    "smlal v11.4s, v30.4h, v8.4h\n"
-    "add x24, x24, x17\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x24, [x12, #0x50]\n"
+    "smlal v13.4s, v30.4h, v8.4h\n"
     "smlal2 v26.4s, v30.8h, v8.8h\n"
-    "smlal v23.4s, v30.4h, v7.4h\n"
-    "smlal2 v9.4s, v30.8h, v7.8h\n"
-    "smlal v12.4s, v30.4h, v5.4h\n"
-    "smlal2 v22.4s, v30.8h, v5.8h\n"
-    "smlal v24.4s, v30.4h, v4.4h\n"
-    "smlal2 v10.4s, v30.8h, v4.8h\n"
+    "smlal v19.4s, v30.4h, v7.4h\n"
+    "smlal2 v11.4s, v30.8h, v7.8h\n"
+    "add x24, x24, x15\n"
+    "smlal v18.4s, v30.4h, v5.4h\n"
+    "smlal2 v24.4s, v30.8h, v5.8h\n"
+    "smlal v9.4s, v30.4h, v4.4h\n"
+    "smlal2 v23.4s, v30.8h, v4.8h\n"
     "tbz x8, #2, 33f\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
     "tbz x8, #1, 32f\n"
@@ -886,13 +870,13 @@
     "tbz x8, #0, 35f\n"
     "ld1 { v29.b }[0], [x24]\n"
     "35:"  // Oddments: Load (1, 0): Bit 2: End
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "ldr x23, [x14, #0x58]\n"
-    "smlal v11.4s, v29.4h, v3.4h\n"
-    "add x23, x23, x17\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x23, [x12, #0x58]\n"
+    "smlal v13.4s, v29.4h, v3.4h\n"
     "smlal2 v26.4s, v29.8h, v3.8h\n"
-    "smlal v12.4s, v29.4h, v0.4h\n"
-    "smlal2 v22.4s, v29.8h, v0.8h\n"
+    "smlal v18.4s, v29.4h, v0.4h\n"
+    "smlal2 v24.4s, v29.8h, v0.8h\n"
+    "add x23, x23, x15\n"
     "tbz x8, #2, 37f\n"
     "ld1 { v28.s }[0], [x23], #0x4\n"
     "tbz x8, #1, 36f\n"
@@ -914,13 +898,13 @@
     "tbz x8, #0, 39f\n"
     "ld1 { v28.b }[0], [x23]\n"
     "39:"  // Oddments: Load (1, 3): Bit 2: End
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "ldr x22, [x14, #0x60]\n"
-    "smlal v23.4s, v28.4h, v5.4h\n"
-    "add x22, x22, x17\n"
-    "smlal2 v9.4s, v28.8h, v5.8h\n"
-    "smlal v24.4s, v28.4h, v2.4h\n"
-    "smlal2 v10.4s, v28.8h, v2.8h\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "ldr x22, [x12, #0x60]\n"
+    "smlal v19.4s, v28.4h, v5.4h\n"
+    "smlal2 v11.4s, v28.8h, v5.8h\n"
+    "smlal v9.4s, v28.4h, v2.4h\n"
+    "smlal2 v23.4s, v28.8h, v2.8h\n"
+    "add x22, x22, x15\n"
     "tbz x8, #2, 41f\n"
     "ld1 { v31.s }[0], [x22], #0x4\n"
     "tbz x8, #1, 40f\n"
@@ -942,13 +926,13 @@
     "tbz x8, #0, 43f\n"
     "ld1 { v31.b }[0], [x22]\n"
     "43:"  // Oddments: Load (2, 0): Bit 2: End
-    "usubl v31.8h, v31.8b, v21.8b\n"
-    "ldr x21, [x14, #0x68]\n"
-    "smlal v11.4s, v31.4h, v6.4h\n"
-    "add x21, x21, x17\n"
+    "usubl v31.8h, v31.8b, v22.8b\n"
+    "ldr x21, [x12, #0x68]\n"
+    "smlal v13.4s, v31.4h, v6.4h\n"
     "smlal2 v26.4s, v31.8h, v6.8h\n"
-    "smlal v12.4s, v31.4h, v3.4h\n"
-    "smlal2 v22.4s, v31.8h, v3.8h\n"
+    "smlal v18.4s, v31.4h, v3.4h\n"
+    "smlal2 v24.4s, v31.8h, v3.8h\n"
+    "add x21, x21, x15\n"
     "tbz x8, #2, 45f\n"
     "ld1 { v30.s }[0], [x21], #0x4\n"
     "tbz x8, #1, 44f\n"
@@ -970,13 +954,13 @@
     "tbz x8, #0, 47f\n"
     "ld1 { v30.b }[0], [x21]\n"
     "47:"  // Oddments: Load (2, 3): Bit 2: End
-    "usubl v30.8h, v30.8b, v21.8b\n"
-    "ldr x20, [x14, #0x70]\n"
-    "smlal v23.4s, v30.4h, v8.4h\n"
-    "add x20, x20, x17\n"
-    "smlal2 v9.4s, v30.8h, v8.8h\n"
-    "smlal v24.4s, v30.4h, v5.4h\n"
-    "smlal2 v10.4s, v30.8h, v5.8h\n"
+    "usubl v30.8h, v30.8b, v22.8b\n"
+    "ldr x20, [x12, #0x70]\n"
+    "smlal v19.4s, v30.4h, v8.4h\n"
+    "smlal2 v11.4s, v30.8h, v8.8h\n"
+    "smlal v9.4s, v30.4h, v5.4h\n"
+    "smlal2 v23.4s, v30.8h, v5.8h\n"
+    "add x20, x20, x15\n"
     "tbz x8, #2, 49f\n"
     "ld1 { v29.s }[0], [x20], #0x4\n"
     "tbz x8, #1, 48f\n"
@@ -998,13 +982,13 @@
     "tbz x8, #0, 51f\n"
     "ld1 { v29.b }[0], [x20]\n"
     "51:"  // Oddments: Load (3, 1): Bit 2: End
-    "usubl v29.8h, v29.8b, v21.8b\n"
-    "ldr x19, [x14, #0x78]\n"
-    "smlal v12.4s, v29.4h, v7.4h\n"
-    "add x19, x19, x17\n"
-    "smlal2 v22.4s, v29.8h, v7.8h\n"
-    "smlal v24.4s, v29.4h, v6.4h\n"
-    "smlal2 v10.4s, v29.8h, v6.8h\n"
+    "usubl v29.8h, v29.8b, v22.8b\n"
+    "ldr x19, [x12, #0x78]\n"
+    "smlal v18.4s, v29.4h, v7.4h\n"
+    "smlal2 v24.4s, v29.8h, v7.8h\n"
+    "smlal v9.4s, v29.4h, v6.4h\n"
+    "smlal2 v23.4s, v29.8h, v6.8h\n"
+    "add x19, x19, x15\n"
     "tbz x8, #2, 53f\n"
     "ld1 { v28.s }[0], [x19], #0x4\n"
     "tbz x8, #1, 52f\n"
@@ -1026,160 +1010,150 @@
     "tbz x8, #0, 55f\n"
     "ld1 { v28.b }[0], [x19]\n"
     "55:"  // Oddments: Load (3, 2): Bit 2: End
-    "usubl v28.8h, v28.8b, v21.8b\n"
-    "smlal v12.4s, v28.4h, v8.4h\n"
-    "smlal2 v22.4s, v28.8h, v8.8h\n"
-    "smlal v24.4s, v28.4h, v7.4h\n"
-    "smlal2 v10.4s, v28.8h, v7.8h\n"
+    "usubl v28.8h, v28.8b, v22.8b\n"
+    "smlal v18.4s, v28.4h, v8.4h\n"
+    "smlal2 v24.4s, v28.8h, v8.8h\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v23.4s, v28.8h, v7.8h\n"
     "tbz x8, #2, 57f\n"
-    "ld1 { v25.4s }, [x13], #0x10\n"
-    "ld1 { v18.4s }, [x11], #0x10\n"
+    "ld1 { v21.4s }, [x13], #0x10\n"
+    "ld1 { v25.4s }, [x11], #0x10\n"
     "tbz x8, #1, 56f\n"
-    "ld1 { v16.d }[0], [x13], #0x8\n"
-    "ld1 { v20.d }[0], [x11], #0x8\n"
+    "ld1 { v10.d }[0], [x13], #0x8\n"
+    "ld1 { v16.d }[0], [x11], #0x8\n"
     "tbz x8, #0, 59f\n"
-    "ld1 { v16.s }[2], [x13]\n"
-    "ld1 { v20.s }[2], [x11]\n"
+    "ld1 { v10.s }[2], [x13]\n"
+    "ld1 { v16.s }[2], [x11]\n"
     "b 59f\n"
     "56:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
     "tbz x8, #0, 59f\n"
-    "ld1 { v16.s }[0], [x13]\n"
-    "ld1 { v20.s }[0], [x11]\n"
+    "ld1 { v10.s }[0], [x13]\n"
+    "ld1 { v16.s }[0], [x11]\n"
     "b 59f\n"
     "57:"  // Oddments: Load requant params: Bit 2: Unset
     "tbz x8, #1, 58f\n"
-    "ld1 { v25.d }[0], [x13], #0x8\n"
-    "ld1 { v18.d }[0], [x11], #0x8\n"
+    "ld1 { v21.d }[0], [x13], #0x8\n"
+    "ld1 { v25.d }[0], [x11], #0x8\n"
     "tbz x8, #0, 59f\n"
-    "ld1 { v25.s }[2], [x13]\n"
-    "ld1 { v18.s }[2], [x11]\n"
+    "ld1 { v21.s }[2], [x13]\n"
+    "ld1 { v25.s }[2], [x11]\n"
     "b 59f\n"
     "58:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
     "tbz x8, #0, 59f\n"
-    "ld1 { v25.s }[0], [x13]\n"
-    "ld1 { v18.s }[0], [x11]\n"
+    "ld1 { v21.s }[0], [x13]\n"
+    "ld1 { v25.s }[0], [x11]\n"
     "59:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v11.4s, v11.4s, v25.4s\n"
-    "add x10, x10, x15\n"
-    "sqrdmulh v26.4s, v26.4s, v16.4s\n"
-    "add x9, x9, x15\n"
-    "sqrdmulh v23.4s, v23.4s, v25.4s\n"
-    "add x28, x28, x15\n"
-    "sqrdmulh v9.4s, v9.4s, v16.4s\n"
-    "add x27, x27, x15\n"
-    "sqrdmulh v12.4s, v12.4s, v25.4s\n"
-    "and v19.16b, v11.16b, v18.16b\n"
-    "and v5.16b, v26.16b, v20.16b\n"
-    "and v28.16b, v23.16b, v18.16b\n"
-    "sshr v19.4s, v19.4s, #0x1f\n"
-    "sshr v5.4s, v5.4s, #0x1f\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
-    "sqadd v11.4s, v11.4s, v19.4s\n"
-    "sqadd v26.4s, v26.4s, v5.4s\n"
-    "sqadd v23.4s, v23.4s, v28.4s\n"
-    "and v8.16b, v9.16b, v20.16b\n"
-    "srshl v11.4s, v11.4s, v18.4s\n"
-    "srshl v26.4s, v26.4s, v20.4s\n"
-    "srshl v23.4s, v23.4s, v18.4s\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "add v11.4s, v11.4s, v13.4s\n"
-    "add v26.4s, v26.4s, v13.4s\n"
-    "add v23.4s, v23.4s, v13.4s\n"
-    "smin v11.4s, v11.4s, v14.4s\n"
-    "smin v26.4s, v26.4s, v14.4s\n"
-    "smin v23.4s, v23.4s, v14.4s\n"
-    "smax v11.4s, v11.4s, v15.4s\n"
-    "smax v26.4s, v26.4s, v15.4s\n"
-    "smax v23.4s, v23.4s, v15.4s\n"
-    "sqadd v9.4s, v9.4s, v8.4s\n"
-    "uzp1 v11.16b, v11.16b, v26.16b\n"
-    "and v1.16b, v12.16b, v18.16b\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "srshl v9.4s, v9.4s, v20.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "sqrdmulh v22.4s, v22.4s, v16.4s\n"
-    "sqrdmulh v24.4s, v24.4s, v25.4s\n"
-    "add v9.4s, v9.4s, v13.4s\n"
-    "sqadd v12.4s, v12.4s, v1.4s\n"
-    "and v0.16b, v22.16b, v20.16b\n"
-    "smin v9.4s, v9.4s, v14.4s\n"
-    "and v26.16b, v24.16b, v18.16b\n"
-    "srshl v12.4s, v12.4s, v18.4s\n"
-    "smax v9.4s, v9.4s, v15.4s\n"
-    "sshr v0.4s, v0.4s, #0x1f\n"
-    "sshr v26.4s, v26.4s, #0x1f\n"
-    "uzp1 v23.16b, v23.16b, v9.16b\n"
-    "add v12.4s, v12.4s, v13.4s\n"
-    "uzp1 v23.16b, v23.16b, v23.16b\n"
-    "sqadd v22.4s, v22.4s, v0.4s\n"
-    "smin v12.4s, v12.4s, v14.4s\n"
-    "sqadd v24.4s, v24.4s, v26.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v16.4s\n"
-    "smax v12.4s, v12.4s, v15.4s\n"
-    "srshl v22.4s, v22.4s, v20.4s\n"
-    "srshl v24.4s, v24.4s, v18.4s\n"
-    "and v16.16b, v10.16b, v20.16b\n"
-    "add v22.4s, v22.4s, v13.4s\n"
-    "add v24.4s, v24.4s, v13.4s\n"
-    "sshr v16.4s, v16.4s, #0x1f\n"
-    "smin v22.4s, v22.4s, v14.4s\n"
-    "smin v24.4s, v24.4s, v14.4s\n"
-    "sqadd v10.4s, v10.4s, v16.4s\n"
-    "smax v22.4s, v22.4s, v15.4s\n"
-    "smax v24.4s, v24.4s, v15.4s\n"
-    "srshl v10.4s, v10.4s, v20.4s\n"
-    "uzp1 v12.16b, v12.16b, v22.16b\n"
-    "uzp1 v12.16b, v12.16b, v12.16b\n"
-    "add v10.4s, v10.4s, v13.4s\n"
-    "smin v10.4s, v10.4s, v14.4s\n"
-    "smax v10.4s, v10.4s, v15.4s\n"
-    "uzp1 v24.16b, v24.16b, v10.16b\n"
-    "uzp1 v24.16b, v24.16b, v24.16b\n"
+    "sqdmulh v13.4s, v13.4s, v21.4s\n"
+    "sqdmulh v19.4s, v19.4s, v21.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v18.4s, v18.4s, v21.4s\n"
+    "sqdmulh v9.4s, v9.4s, v21.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v7.16b, v13.16b, v25.16b\n"
+    "sqdmulh v26.4s, v26.4s, v10.4s\n"
+    "and v4.16b, v19.16b, v25.16b\n"
+    "sqdmulh v11.4s, v11.4s, v10.4s\n"
+    "and v21.16b, v18.16b, v25.16b\n"
+    "sqdmulh v24.4s, v24.4s, v10.4s\n"
+    "and v20.16b, v9.16b, v25.16b\n"
+    "sqdmulh v23.4s, v23.4s, v10.4s\n"
+    "sshr v7.4s, v7.4s, #0x1f\n"
+    "and v29.16b, v26.16b, v16.16b\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "and v10.16b, v11.16b, v16.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v31.16b, v24.16b, v16.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v30.16b, v23.16b, v16.16b\n"
+    "sqadd v13.4s, v13.4s, v7.4s\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "sqadd v19.4s, v19.4s, v4.4s\n"
+    "sshr v10.4s, v10.4s, #0x1f\n"
+    "sqadd v18.4s, v18.4s, v21.4s\n"
+    "sshr v31.4s, v31.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v20.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "srshl v13.4s, v13.4s, v25.4s\n"
+    "sqadd v26.4s, v26.4s, v29.4s\n"
+    "srshl v19.4s, v19.4s, v25.4s\n"
+    "sqadd v11.4s, v11.4s, v10.4s\n"
+    "srshl v18.4s, v18.4s, v25.4s\n"
+    "sqadd v24.4s, v24.4s, v31.4s\n"
+    "srshl v9.4s, v9.4s, v25.4s\n"
+    "sqadd v23.4s, v23.4s, v30.4s\n"
+    "srshl v26.4s, v26.4s, v16.4s\n"
+    "sqxtn v13.4h, v13.4s\n"
+    "srshl v11.4s, v11.4s, v16.4s\n"
+    "sqxtn v19.4h, v19.4s\n"
+    "srshl v24.4s, v24.4s, v16.4s\n"
+    "sqxtn v18.4h, v18.4s\n"
+    "srshl v23.4s, v23.4s, v16.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "sqxtn2 v13.8h, v26.4s\n"
+    "sqxtn2 v19.8h, v11.4s\n"
+    "sqxtn2 v18.8h, v24.4s\n"
+    "sqxtn2 v9.8h, v23.4s\n"
+    "sqadd v13.8h, v13.8h, v14.8h\n"
+    "sqadd v19.8h, v19.8h, v14.8h\n"
+    "sqadd v18.8h, v18.8h, v14.8h\n"
+    "sqadd v9.8h, v9.8h, v14.8h\n"
+    "smax v13.8h, v13.8h, v17.8h\n"
+    "smax v19.8h, v19.8h, v17.8h\n"
+    "smax v18.8h, v18.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smin v13.8h, v13.8h, v15.8h\n"
+    "smin v19.8h, v19.8h, v15.8h\n"
+    "smin v18.8h, v18.8h, v15.8h\n"
+    "smin v9.8h, v9.8h, v15.8h\n"
+    "uzp1 v13.16b, v13.16b, v13.16b\n"
+    "uzp1 v19.16b, v19.16b, v19.16b\n"
+    "uzp1 v18.16b, v18.16b, v18.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
     "tbz x8, #2, 61f\n"
-    "st1 { v11.s }[0], [x10], #0x4\n"
-    "st1 { v23.s }[0], [x9], #0x4\n"
-    "st1 { v12.s }[0], [x28], #0x4\n"
-    "st1 { v24.s }[0], [x27], #0x4\n"
+    "st1 { v13.s }[0], [x10], #0x4\n"
+    "st1 { v19.s }[0], [x9], #0x4\n"
+    "st1 { v18.s }[0], [x28], #0x4\n"
+    "st1 { v9.s }[0], [x27], #0x4\n"
     "tbz x8, #1, 60f\n"
-    "st1 { v11.h }[2], [x10], #0x2\n"
-    "st1 { v23.h }[2], [x9], #0x2\n"
-    "st1 { v12.h }[2], [x28], #0x2\n"
-    "st1 { v24.h }[2], [x27], #0x2\n"
+    "st1 { v13.h }[2], [x10], #0x2\n"
+    "st1 { v19.h }[2], [x9], #0x2\n"
+    "st1 { v18.h }[2], [x28], #0x2\n"
+    "st1 { v9.h }[2], [x27], #0x2\n"
     "tbz x8, #0, 63f\n"
-    "st1 { v11.b }[6], [x10], #0x1\n"
-    "st1 { v23.b }[6], [x9], #0x1\n"
-    "st1 { v12.b }[6], [x28], #0x1\n"
-    "st1 { v24.b }[6], [x27], #0x1\n"
+    "st1 { v13.b }[6], [x10], #0x1\n"
+    "st1 { v19.b }[6], [x9], #0x1\n"
+    "st1 { v18.b }[6], [x28], #0x1\n"
+    "st1 { v9.b }[6], [x27], #0x1\n"
     "b 63f\n"
     "60:"  // Oddments: Bit 2: Bit 1: Unset
     "tbz x8, #0, 63f\n"
-    "st1 { v11.b }[4], [x10], #0x1\n"
-    "st1 { v23.b }[4], [x9], #0x1\n"
-    "st1 { v12.b }[4], [x28], #0x1\n"
-    "st1 { v24.b }[4], [x27], #0x1\n"
+    "st1 { v13.b }[4], [x10], #0x1\n"
+    "st1 { v19.b }[4], [x9], #0x1\n"
+    "st1 { v18.b }[4], [x28], #0x1\n"
+    "st1 { v9.b }[4], [x27], #0x1\n"
     "b 63f\n"
     "61:"  // Oddments: Bit 2: Unset
     "tbz x8, #1, 62f\n"
-    "st1 { v11.h }[0], [x10], #0x2\n"
-    "st1 { v23.h }[0], [x9], #0x2\n"
-    "st1 { v12.h }[0], [x28], #0x2\n"
-    "st1 { v24.h }[0], [x27], #0x2\n"
+    "st1 { v13.h }[0], [x10], #0x2\n"
+    "st1 { v19.h }[0], [x9], #0x2\n"
+    "st1 { v18.h }[0], [x28], #0x2\n"
+    "st1 { v9.h }[0], [x27], #0x2\n"
     "tbz x8, #0, 63f\n"
-    "st1 { v11.b }[2], [x10], #0x1\n"
-    "st1 { v23.b }[2], [x9], #0x1\n"
-    "st1 { v12.b }[2], [x28], #0x1\n"
-    "st1 { v24.b }[2], [x27], #0x1\n"
+    "st1 { v13.b }[2], [x10], #0x1\n"
+    "st1 { v19.b }[2], [x9], #0x1\n"
+    "st1 { v18.b }[2], [x28], #0x1\n"
+    "st1 { v9.b }[2], [x27], #0x1\n"
     "b 63f\n"
     "62:"  // Oddments: Bit 2: Unset: Bit 1: Unset
     "tbz x8, #0, 63f\n"
-    "st1 { v11.b }[0], [x10], #0x1\n"
-    "st1 { v23.b }[0], [x9], #0x1\n"
-    "st1 { v12.b }[0], [x28], #0x1\n"
-    "st1 { v24.b }[0], [x27], #0x1\n"
+    "st1 { v13.b }[0], [x10], #0x1\n"
+    "st1 { v19.b }[0], [x9], #0x1\n"
+    "st1 { v18.b }[0], [x28], #0x1\n"
+    "st1 { v9.b }[0], [x27], #0x1\n"
     "63:"  // Oddments: Bit 2: End
-
     "64:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index 77861e9..9a1b64e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst
+class a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>;  // NOTE(review): presumably <input, weight, output, bias> element types - confirm against the parent template
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 2, 2) {}  // NOTE(review): presumably (output rows, output cols, kernel rows, kernel cols, stride rows, stride cols) - confirm against the Parent constructor
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_3x3_mla::get_packed_size;
-
-  kern_type kernel = a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
-
-  a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;  // Defaults to the _impl function declared just above this class
+  Parent::KernelType get_kernel(void) const override { return kernel; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): meaning and units of this value are defined by the parent strategy - confirm
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
index 4e1586b..790d26b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const int8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const uint8_t *const *inptrs_raw,
-      const int8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -100,611 +100,595 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
-    "ldr x3, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x4, #0x0\n"
-    "ldr x5, [%x[params], %[offsetof_Params_weights]]\n"
-    "mov x6, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x7, %x[params], %[offsetof_Params_inptrs]\n"
-    "ldr x8, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x17, x3, #0x3\n"
-    "ldr x16, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x19, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v22.16b }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v12.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v14.4s }, [x19]\n"
-    "add x19, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v16.4s }, [x20]\n"
-    "ld1r { v15.4s }, [x19]\n"
-    "ldp x15, x14, [x21, #0x0]\n"
-    "ldp x13, x12, [x21, #0x10]\n"
-    "cbz x17, 3f\n"
-    "subs x17, x17, #0x1\n"
+    "ldr x19, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x8, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x24, x19, %[offsetof_Requantize32_a_offset]\n"
+    "add x23, x19, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x22, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x21, x19, %[offsetof_Requantize32_c_offset]\n"
+    "add x20, x19, %[offsetof_Requantize32_minval]\n"
+    "ldr x17, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x19, x19, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v12.16b }, [x24]\n"
+    "ld1r { v13.16b }, [x23]\n"
+    "lsr x16, x8, #0x3\n"
+    "ld1r { v11.8h }, [x21]\n"
+    "ld1r { v17.8h }, [x20]\n"
+    "mov x15, #0x0\n"
+    "mov x14, #0x0\n"
+    "ld1r { v14.8h }, [x19]\n"
+    "ldr x13, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "add x12, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x11, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x10, x9, [x22, #0x0]\n"
+    "ldp x28, x27, [x22, #0x10]\n"
+    "cbz x16, 3f\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q13, [x19, #0x0]\n"
-    "mov v19.16b, v13.16b\n"
+    "ldr q15, [x19, #0x0]\n"
+    "subs x16, x16, #0x1\n"
+    "mov v9.16b, v15.16b\n"
     "ldr q10, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v11.16b, v13.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v18.16b, v13.16b\n"
-    "ldr d0, [x5, #0x0]\n"
-    "ldr d1, [x5, #0x8]\n"
-    "mov v20.16b, v10.16b\n"
-    "ldr d2, [x5, #0x10]\n"
-    "mov v17.16b, v10.16b\n"
-    "ldr d3, [x5, #0x18]\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v16.16b, v10.16b\n"
+    "mov v22.16b, v15.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
     "mov v21.16b, v10.16b\n"
-    "ldr d4, [x5, #0x20]\n"
-    "ssubl v0.8h, v0.8b, v12.8b\n"
-    "ldr d5, [x5, #0x28]\n"
-    "ssubl v1.8h, v1.8b, v12.8b\n"
-    "ldr d6, [x5, #0x30]\n"
-    "ssubl v2.8h, v2.8b, v12.8b\n"
-    "ldr d7, [x5, #0x38]\n"
-    "ssubl v3.8h, v3.8b, v12.8b\n"
-    "ldr d8, [x5, #0x40]\n"
-    "ssubl v4.8h, v4.8b, v12.8b\n"
-    "ldp x26, x25, [x7, #0x0]\n"
-    "ssubl v5.8h, v5.8b, v12.8b\n"
-    "ldp x24, x23, [x7, #0x10]\n"
-    "ssubl v6.8h, v6.8b, v12.8b\n"
-    "ssubl v7.8h, v7.8b, v12.8b\n"
-    "ldp x22, x21, [x7, #0x20]\n"
-    "ssubl v8.8h, v8.8b, v12.8b\n"
-    "ldp x20, x19, [x7, #0x30]\n"
-    "ldr d31, [x26, x4]\n"
-    "usubl v31.8h, v31.8b, v22.8b\n"
-    "ldr d30, [x25, x4]\n"
-    "ldr d29, [x24, x4]\n"
-    "usubl v30.8h, v30.8b, v22.8b\n"
-    "ldr d28, [x23, x4]\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "ldr d27, [x22, x4]\n"
-    "ldr d26, [x21, x4]\n"
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "ldr d25, [x20, x4]\n"
-    "ldr d24, [x19, x4]\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "usubl v24.8h, v24.8b, v22.8b\n"
+    "mov v23.16b, v15.16b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v18.16b, v10.16b\n"
+    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v1.8h, v1.8b, v13.8b\n"
+    "ssubl v2.8h, v2.8b, v13.8b\n"
+    "ldp x26, x25, [x12, #0x0]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
+    "ssubl v3.8h, v3.8b, v13.8b\n"
+    "ssubl v4.8h, v4.8b, v13.8b\n"
+    "ldp x22, x21, [x12, #0x20]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
+    "ssubl v5.8h, v5.8b, v13.8b\n"
+    "ssubl v6.8h, v6.8b, v13.8b\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
+    "ssubl v7.8h, v7.8b, v13.8b\n"
+    "ssubl v8.8h, v8.8b, v13.8b\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v31.8h, v31.8b, v12.8b\n"
+    "usubl v30.8h, v30.8b, v12.8b\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
     "beq 2f\n"
     "1:"  // Loop
-    "smlal v13.4s, v31.4h, v8.4h\n"
-    "ldr x22, [x7, #0x40]\n"
-    "add x5, x5, #0x48\n"
+    "smlal v15.4s, v31.4h, v8.4h\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "ldr x21, [x7, #0x48]\n"
-    "subs x17, x17, #0x1\n"
-    "smlal v19.4s, v31.4h, v6.4h\n"
-    "ldr x20, [x7, #0x50]\n"
-    "smlal2 v20.4s, v31.8h, v6.8h\n"
-    "ldr x19, [x7, #0x58]\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "ldr x11, [x7, #0x60]\n"
-    "smlal2 v17.4s, v31.8h, v2.8h\n"
-    "ldr x10, [x7, #0x68]\n"
-    "smlal v18.4s, v31.4h, v0.4h\n"
-    "ldr x9, [x7, #0x70]\n"
-    "smlal2 v21.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x7, #0x78]\n"
-    "smlal v13.4s, v30.4h, v0.4h\n"
-    "ldr x27, [x7, #0x80]\n"
+    "ldr x24, [x12, #0x40]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x12, #0x50]\n"
+    "ldr x19, [x12, #0x58]\n"
+    "smlal v15.4s, v30.4h, v0.4h\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "ldr x26, [x7, #0x88]\n"
-    "smlal v19.4s, v28.4h, v1.4h\n"
-    "ldr x25, [x7, #0x90]\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x21, x4]\n"
-    "smlal v13.4s, v29.4h, v1.4h\n"
-    "ldr x24, [x7, #0x98]\n"
+    "ldr x22, [x12, #0x78]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "smlal v15.4s, v29.4h, v1.4h\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "ldr d29, [x22, x4]\n"
-    "smlal v19.4s, v27.4h, v2.4h\n"
-    "ldr x23, [x7, #0xa0]\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x20, x4]\n"
-    "smlal v13.4s, v26.4h, v3.4h\n"
-    "ldr x22, [x7, #0xa8]\n"
+    "ldr d29, [x24, x15]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v15.4s, v26.4h, v3.4h\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x19, x4]\n"
-    "smlal v19.4s, v24.4h, v0.4h\n"
-    "ldr x21, [x7, #0xb0]\n"
-    "smlal2 v20.4s, v24.8h, v0.8h\n"
-    "ldr x20, [x7, #0xb8]\n"
-    "smlal v13.4s, v25.4h, v4.4h\n"
-    "ldr x19, [x7, #0xc0]\n"
+    "ldr d26, [x19, x15]\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "ldr x19, [x12, #0x68]\n"
+    "smlal v15.4s, v25.4h, v4.4h\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
-    "ldr d25, [x11, x4]\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "ldr q31, [x8, #0x0]\n"
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "ldr q30, [x16, #0x0]\n"
-    "smlal v13.4s, v24.4h, v2.4h\n"
-    "ldr q23, [x8, #0x10]\n"
-    "add x8, x8, #0x20\n"
+    "ldr d25, [x20, x15]\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x9, x4]\n"
-    "smlal v19.4s, v29.4h, v4.4h\n"
-    "ldr q9, [x16, #0x10]\n"
-    "add x16, x16, #0x20\n"
-    "smlal2 v20.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x10, x4]\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "smlal v19.4s, v28.4h, v5.4h\n"
-    "smlal v13.4s, v27.4h, v5.4h\n"
-    "smlal2 v20.4s, v28.8h, v5.8h\n"
-    "ldr d28, [x27, x4]\n"
+    "ldr x19, [x12, #0x70]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x12, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
     "smlal2 v10.4s, v27.8h, v5.8h\n"
-    "smlal v19.4s, v27.4h, v3.4h\n"
-    "smlal v11.4s, v26.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x28, x4]\n"
-    "smlal2 v17.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x26, x4]\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "usubl v24.8h, v24.8b, v22.8b\n"
-    "smlal v13.4s, v25.4h, v6.4h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "smlal2 v18.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "ldr q19, [x13, #0x0]\n"
     "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v17.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x25, x4]\n"
-    "smlal v13.4s, v24.4h, v7.4h\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v18.4s, v28.8h, v1.8h\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q0, [x11, #0x0]\n"
+    "ldr q4, [x13, #0x10]\n"
     "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "smlal v11.4s, v29.4h, v4.4h\n"
-    "smlal2 v17.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x24, x4]\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v17.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x22, x4]\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x4]\n"
-    "smlal v19.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
-    "smlal v18.4s, v28.4h, v1.4h\n"
-    "smlal2 v21.4s, v28.8h, v1.8h\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "smlal v18.4s, v26.4h, v5.4h\n"
-    "smlal2 v21.4s, v26.8h, v5.8h\n"
-    "ldr d26, [x21, x4]\n"
-    "smlal v11.4s, v25.4h, v6.4h\n"
-    "smlal2 v17.4s, v25.8h, v6.8h\n"
-    "ldr d25, [x20, x4]\n"
-    "smlal v19.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "smlal v18.4s, v29.4h, v2.4h\n"
-    "smlal2 v21.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x19, x4]\n"
-    "add x4, x4, #0x8\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v24.8h, v24.8b, v22.8b\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "smlal v11.4s, v27.4h, v7.4h\n"
-    "smlal2 v17.4s, v27.8h, v7.8h\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v24.4h, v5.4h\n"
-    "smlal2 v17.4s, v24.8h, v5.8h\n"
-    "smlal v18.4s, v26.4h, v7.4h\n"
-    "smlal2 v21.4s, v26.8h, v7.8h\n"
-    "smlal v11.4s, v25.4h, v8.4h\n"
-    "smlal2 v17.4s, v25.8h, v8.8h\n"
-    "smlal v18.4s, v25.4h, v6.4h\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "ldr q31, [x11, #0x10]\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v18.4s, v26.8h, v5.8h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v18.4s, v29.8h, v2.8h\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "add x17, x17, #0x48\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "subs x16, x16, #0x1\n"
     "smlal2 v21.4s, v25.8h, v6.8h\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "sqrdmulh v13.4s, v13.4s, v31.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v23.4s\n"
-    "smlal v18.4s, v29.4h, v8.4h\n"
-    "smlal2 v21.4s, v29.8h, v8.8h\n"
-    "and v27.16b, v13.16b, v30.16b\n"
-    "and v7.16b, v10.16b, v9.16b\n"
-    "sqrdmulh v19.4s, v19.4s, v31.4s\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "sshr v7.4s, v7.4s, #0x1f\n"
-    "sqrdmulh v20.4s, v20.4s, v23.4s\n"
-    "sqadd v13.4s, v13.4s, v27.4s\n"
-    "sqadd v10.4s, v10.4s, v7.4s\n"
-    "and v6.16b, v19.16b, v30.16b\n"
-    "and v3.16b, v20.16b, v9.16b\n"
-    "srshl v13.4s, v13.4s, v30.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "add v13.4s, v13.4s, v14.4s\n"
-    "add v10.4s, v10.4s, v14.4s\n"
-    "sqadd v19.4s, v19.4s, v6.4s\n"
-    "smin v13.4s, v13.4s, v15.4s\n"
-    "smin v10.4s, v10.4s, v15.4s\n"
-    "sqadd v20.4s, v20.4s, v3.4s\n"
-    "smax v13.4s, v13.4s, v16.4s\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "srshl v19.4s, v19.4s, v30.4s\n"
-    "srshl v20.4s, v20.4s, v9.4s\n"
-    "uzp1 v13.16b, v13.16b, v10.16b\n"
-    "sqrdmulh v11.4s, v11.4s, v31.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "str d13, [x15, x6]\n"
-    "add v19.4s, v19.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v14.4s\n"
-    "and v28.16b, v11.16b, v30.16b\n"
-    "sqrdmulh v17.4s, v17.4s, v23.4s\n"
-    "smin v19.4s, v19.4s, v15.4s\n"
-    "smin v20.4s, v20.4s, v15.4s\n"
+    "ldr d25, [x20, x15]\n"
+    "smlal2 v18.4s, v24.8h, v3.8h\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "add x13, x13, #0x20\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
+    "smlal2 v18.4s, v26.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "add x11, x11, #0x20\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
+    "smlal2 v18.4s, v25.8h, v6.8h\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
+    "smlal2 v18.4s, v29.8h, v8.8h\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "and v19.16b, v10.16b, v31.16b\n"
     "sshr v28.4s, v28.4s, #0x1f\n"
-    "smax v19.4s, v19.4s, v16.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "sqadd v11.4s, v11.4s, v28.4s\n"
-    "and v26.16b, v17.16b, v9.16b\n"
-    "uzp1 v19.16b, v19.16b, v20.16b\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "uzp1 v19.16b, v19.16b, v19.16b\n"
-    "str d19, [x14, x6]\n"
-    "srshl v11.4s, v11.4s, v30.4s\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
     "sshr v26.4s, v26.4s, #0x1f\n"
-    "and v8.16b, v18.16b, v30.16b\n"
-    "sqrdmulh v21.4s, v21.4s, v23.4s\n"
-    "sqadd v17.4s, v17.4s, v26.4s\n"
-    "add v11.4s, v11.4s, v14.4s\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "and v27.16b, v21.16b, v9.16b\n"
-    "smin v11.4s, v11.4s, v15.4s\n"
-    "srshl v17.4s, v17.4s, v9.4s\n"
-    "sqadd v18.4s, v18.4s, v8.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "add v17.4s, v17.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v30.4s\n"
-    "sqadd v21.4s, v21.4s, v27.4s\n"
-    "smin v17.4s, v17.4s, v15.4s\n"
-    "add v18.4s, v18.4s, v14.4s\n"
-    "smax v17.4s, v17.4s, v16.4s\n"
-    "srshl v21.4s, v21.4s, v9.4s\n"
-    "smin v18.4s, v18.4s, v15.4s\n"
-    "uzp1 v11.16b, v11.16b, v17.16b\n"
-    "add v21.4s, v21.4s, v14.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x13, x6]\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "smin v21.4s, v21.4s, v15.4s\n"
-    "smax v21.4s, v21.4s, v16.4s\n"
-    "uzp1 v18.16b, v18.16b, v21.16b\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "str d18, [x12, x6]\n"
-    "add x6, x6, #0x8\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "str d9, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d22, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q13, [x19, #0x0]\n"
-    "mov v19.16b, v13.16b\n"
+    "ldr q15, [x19, #0x0]\n"
+    "add x14, x14, #0x8\n"
     "ldr q10, [x19, #0x10]\n"
     "add x19, x19, #0x20\n"
-    "mov v11.16b, v13.16b\n"
     "str x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v18.16b, v13.16b\n"
-    "ldr d0, [x5, #0x0]\n"
-    "ldr d1, [x5, #0x8]\n"
-    "mov v20.16b, v10.16b\n"
-    "ldr d2, [x5, #0x10]\n"
-    "mov v17.16b, v10.16b\n"
-    "ldr d3, [x5, #0x18]\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "ldr d2, [x17, #0x10]\n"
+    "mov v9.16b, v15.16b\n"
+    "mov v16.16b, v10.16b\n"
+    "ldr d3, [x17, #0x18]\n"
+    "ldr d4, [x17, #0x20]\n"
+    "mov v22.16b, v15.16b\n"
     "mov v21.16b, v10.16b\n"
-    "ldr d4, [x5, #0x20]\n"
-    "ssubl v0.8h, v0.8b, v12.8b\n"
-    "ldr d5, [x5, #0x28]\n"
-    "ssubl v1.8h, v1.8b, v12.8b\n"
-    "ldr d6, [x5, #0x30]\n"
-    "ssubl v2.8h, v2.8b, v12.8b\n"
-    "ldr d7, [x5, #0x38]\n"
-    "ssubl v3.8h, v3.8b, v12.8b\n"
-    "ldr d8, [x5, #0x40]\n"
-    "ssubl v4.8h, v4.8b, v12.8b\n"
-    "ldp x26, x25, [x7, #0x0]\n"
-    "ssubl v5.8h, v5.8b, v12.8b\n"
-    "ldp x24, x23, [x7, #0x10]\n"
-    "ssubl v6.8h, v6.8b, v12.8b\n"
-    "ssubl v7.8h, v7.8b, v12.8b\n"
-    "ldp x22, x21, [x7, #0x20]\n"
-    "ssubl v8.8h, v8.8b, v12.8b\n"
-    "ldp x20, x19, [x7, #0x30]\n"
-    "ldr d31, [x26, x4]\n"
-    "usubl v31.8h, v31.8b, v22.8b\n"
-    "ldr d30, [x25, x4]\n"
-    "ldr d29, [x24, x4]\n"
-    "usubl v30.8h, v30.8b, v22.8b\n"
-    "ldr d28, [x23, x4]\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "ldr d27, [x22, x4]\n"
-    "ldr d26, [x21, x4]\n"
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "ldr d25, [x20, x4]\n"
-    "ldr d24, [x19, x4]\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "usubl v24.8h, v24.8b, v22.8b\n"
+    "ldr d5, [x17, #0x28]\n"
+    "ldr d6, [x17, #0x30]\n"
+    "mov v23.16b, v15.16b\n"
+    "mov v18.16b, v10.16b\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "ssubl v1.8h, v1.8b, v13.8b\n"
+    "ldp x26, x25, [x12, #0x0]\n"
+    "ldp x24, x23, [x12, #0x10]\n"
+    "ssubl v2.8h, v2.8b, v13.8b\n"
+    "ssubl v3.8h, v3.8b, v13.8b\n"
+    "ldp x22, x21, [x12, #0x20]\n"
+    "ldp x20, x19, [x12, #0x30]\n"
+    "ssubl v4.8h, v4.8b, v13.8b\n"
+    "ssubl v5.8h, v5.8b, v13.8b\n"
+    "ldr d31, [x26, x15]\n"
+    "ldr d30, [x25, x15]\n"
+    "ssubl v6.8h, v6.8b, v13.8b\n"
+    "ssubl v7.8h, v7.8b, v13.8b\n"
+    "ldr d29, [x24, x15]\n"
+    "ldr d28, [x23, x15]\n"
+    "ssubl v8.8h, v8.8b, v13.8b\n"
+    "usubl v31.8h, v31.8b, v12.8b\n"
+    "ldr d27, [x22, x15]\n"
+    "ldr d26, [x21, x15]\n"
+    "usubl v30.8h, v30.8b, v12.8b\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "ldr d25, [x20, x15]\n"
+    "ldr d24, [x19, x15]\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
     "bgt 1b\n"
     "2:"  // Tail
-    "smlal v13.4s, v31.4h, v8.4h\n"
-    "ldr x22, [x7, #0x40]\n"
-    "tst x3, #0x7\n"
+    "smlal v15.4s, v31.4h, v8.4h\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "ldr x21, [x7, #0x48]\n"
-    "smlal v19.4s, v31.4h, v6.4h\n"
-    "ldr x20, [x7, #0x50]\n"
-    "smlal2 v20.4s, v31.8h, v6.8h\n"
-    "ldr x19, [x7, #0x58]\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "ldr x11, [x7, #0x60]\n"
-    "smlal2 v17.4s, v31.8h, v2.8h\n"
-    "ldr x10, [x7, #0x68]\n"
-    "smlal v18.4s, v31.4h, v0.4h\n"
-    "ldr x9, [x7, #0x70]\n"
-    "smlal2 v21.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x7, #0x78]\n"
-    "smlal v13.4s, v30.4h, v0.4h\n"
-    "ldr x27, [x7, #0x80]\n"
+    "ldr x24, [x12, #0x40]\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
+    "ldr x21, [x12, #0x50]\n"
+    "ldr x19, [x12, #0x58]\n"
+    "smlal v15.4s, v30.4h, v0.4h\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "ldr x26, [x7, #0x88]\n"
-    "smlal v19.4s, v28.4h, v1.4h\n"
-    "ldr x25, [x7, #0x90]\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x21, x4]\n"
-    "smlal v13.4s, v29.4h, v1.4h\n"
-    "ldr x24, [x7, #0x98]\n"
+    "ldr x22, [x12, #0x78]\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "ldr d28, [x23, x15]\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "smlal v15.4s, v29.4h, v1.4h\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "ldr d29, [x22, x4]\n"
-    "smlal v19.4s, v27.4h, v2.4h\n"
-    "ldr x23, [x7, #0xa0]\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x20, x4]\n"
-    "smlal v13.4s, v26.4h, v3.4h\n"
-    "ldr x22, [x7, #0xa8]\n"
+    "ldr d29, [x24, x15]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v15.4s, v26.4h, v3.4h\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x19, x4]\n"
-    "smlal v19.4s, v24.4h, v0.4h\n"
-    "ldr x21, [x7, #0xb0]\n"
-    "smlal2 v20.4s, v24.8h, v0.8h\n"
-    "ldr x20, [x7, #0xb8]\n"
-    "smlal v13.4s, v25.4h, v4.4h\n"
-    "ldr x19, [x7, #0xc0]\n"
+    "ldr d26, [x19, x15]\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "ldr x21, [x12, #0x80]\n"
+    "ldr x19, [x12, #0x68]\n"
+    "smlal v15.4s, v25.4h, v4.4h\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
-    "ldr d25, [x11, x4]\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "ldr q31, [x8, #0x0]\n"
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "ldr q30, [x16, #0x0]\n"
-    "smlal v13.4s, v24.4h, v2.4h\n"
-    "ldr q23, [x8, #0x10]\n"
-    "add x8, x8, #0x20\n"
+    "ldr d25, [x20, x15]\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "ldr x20, [x12, #0x88]\n"
+    "ldr d29, [x19, x15]\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x9, x4]\n"
-    "smlal v19.4s, v29.4h, v4.4h\n"
-    "ldr q9, [x16, #0x10]\n"
-    "add x16, x16, #0x20\n"
-    "smlal2 v20.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x10, x4]\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "smlal v19.4s, v28.4h, v5.4h\n"
-    "smlal v13.4s, v27.4h, v5.4h\n"
-    "smlal2 v20.4s, v28.8h, v5.8h\n"
-    "ldr d28, [x27, x4]\n"
+    "ldr x19, [x12, #0x70]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "ldr d28, [x21, x15]\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "ldr x24, [x12, #0x98]\n"
+    "ldr d24, [x19, x15]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
     "smlal2 v10.4s, v27.8h, v5.8h\n"
-    "smlal v19.4s, v27.4h, v3.4h\n"
-    "smlal v11.4s, v26.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x28, x4]\n"
-    "smlal2 v17.4s, v26.8h, v3.8h\n"
-    "ldr d26, [x26, x4]\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "usubl v24.8h, v24.8b, v22.8b\n"
-    "smlal v13.4s, v25.4h, v6.4h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "ldr d27, [x22, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
+    "ldr d26, [x20, x15]\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "smlal2 v18.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x19, x15]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "ldr q19, [x13, #0x0]\n"
     "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v17.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x25, x4]\n"
-    "smlal v13.4s, v24.4h, v7.4h\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "ldr d25, [x23, x15]\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "ldr d29, [x24, x15]\n"
+    "smlal2 v18.4s, v28.8h, v1.8h\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
+    "ldr q0, [x11, #0x0]\n"
+    "ldr q4, [x13, #0x10]\n"
     "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "smlal v11.4s, v29.4h, v4.4h\n"
-    "smlal2 v17.4s, v29.8h, v4.8h\n"
-    "ldr d29, [x24, x4]\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v17.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x22, x4]\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x4]\n"
-    "smlal v19.4s, v28.4h, v7.4h\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
-    "smlal v18.4s, v28.4h, v1.4h\n"
-    "smlal2 v21.4s, v28.8h, v1.8h\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "smlal v18.4s, v26.4h, v5.4h\n"
-    "smlal2 v21.4s, v26.8h, v5.8h\n"
-    "ldr d26, [x21, x4]\n"
-    "smlal v11.4s, v25.4h, v6.4h\n"
-    "smlal2 v17.4s, v25.8h, v6.8h\n"
-    "ldr d25, [x20, x4]\n"
-    "smlal v19.4s, v29.4h, v8.4h\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "smlal v18.4s, v29.4h, v2.4h\n"
-    "smlal2 v21.4s, v29.8h, v2.8h\n"
-    "ldr d29, [x19, x4]\n"
-    "add x4, x4, #0x8\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v24.8h, v24.8b, v22.8b\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "smlal v11.4s, v27.4h, v7.4h\n"
-    "smlal2 v17.4s, v27.8h, v7.8h\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "smlal v11.4s, v24.4h, v5.4h\n"
-    "smlal2 v17.4s, v24.8h, v5.8h\n"
-    "smlal v18.4s, v26.4h, v7.4h\n"
-    "smlal2 v21.4s, v26.8h, v7.8h\n"
-    "smlal v11.4s, v25.4h, v8.4h\n"
-    "smlal2 v17.4s, v25.8h, v8.8h\n"
-    "smlal v18.4s, v25.4h, v6.4h\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "ldr q31, [x11, #0x10]\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x22, x15]\n"
+    "smlal2 v18.4s, v26.8h, v5.8h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "ldr d26, [x21, x15]\n"
+    "smlal2 v18.4s, v29.8h, v2.8h\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "tst x8, #0x7\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "add x13, x13, #0x20\n"
     "smlal2 v21.4s, v25.8h, v6.8h\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "sqrdmulh v13.4s, v13.4s, v31.4s\n"
-    "sqrdmulh v10.4s, v10.4s, v23.4s\n"
-    "smlal v18.4s, v29.4h, v8.4h\n"
-    "smlal2 v21.4s, v29.8h, v8.8h\n"
-    "and v27.16b, v13.16b, v30.16b\n"
-    "and v7.16b, v10.16b, v9.16b\n"
-    "sqrdmulh v19.4s, v19.4s, v31.4s\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "sshr v7.4s, v7.4s, #0x1f\n"
-    "sqrdmulh v20.4s, v20.4s, v23.4s\n"
-    "sqadd v13.4s, v13.4s, v27.4s\n"
-    "sqadd v10.4s, v10.4s, v7.4s\n"
-    "and v6.16b, v19.16b, v30.16b\n"
-    "and v3.16b, v20.16b, v9.16b\n"
-    "srshl v13.4s, v13.4s, v30.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "add v13.4s, v13.4s, v14.4s\n"
-    "add v10.4s, v10.4s, v14.4s\n"
-    "sqadd v19.4s, v19.4s, v6.4s\n"
-    "smin v13.4s, v13.4s, v15.4s\n"
-    "smin v10.4s, v10.4s, v15.4s\n"
-    "sqadd v20.4s, v20.4s, v3.4s\n"
-    "smax v13.4s, v13.4s, v16.4s\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "srshl v19.4s, v19.4s, v30.4s\n"
-    "srshl v20.4s, v20.4s, v9.4s\n"
-    "uzp1 v13.16b, v13.16b, v10.16b\n"
-    "sqrdmulh v11.4s, v11.4s, v31.4s\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "str d13, [x15, x6]\n"
-    "add v19.4s, v19.4s, v14.4s\n"
-    "add v20.4s, v20.4s, v14.4s\n"
-    "and v28.16b, v11.16b, v30.16b\n"
-    "sqrdmulh v17.4s, v17.4s, v23.4s\n"
-    "smin v19.4s, v19.4s, v15.4s\n"
-    "smin v20.4s, v20.4s, v15.4s\n"
+    "ldr d25, [x20, x15]\n"
+    "smlal2 v18.4s, v24.8h, v3.8h\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "add x11, x11, #0x20\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "ldr d29, [x19, x15]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
+    "smlal2 v18.4s, v26.8h, v7.8h\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x15, x15, #0x8\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
+    "smlal2 v18.4s, v25.8h, v6.8h\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
+    "smlal2 v18.4s, v29.8h, v8.8h\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "and v19.16b, v10.16b, v31.16b\n"
     "sshr v28.4s, v28.4s, #0x1f\n"
-    "smax v19.4s, v19.4s, v16.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
-    "sqadd v11.4s, v11.4s, v28.4s\n"
-    "and v26.16b, v17.16b, v9.16b\n"
-    "uzp1 v19.16b, v19.16b, v20.16b\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "uzp1 v19.16b, v19.16b, v19.16b\n"
-    "str d19, [x14, x6]\n"
-    "srshl v11.4s, v11.4s, v30.4s\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
     "sshr v26.4s, v26.4s, #0x1f\n"
-    "and v8.16b, v18.16b, v30.16b\n"
-    "sqrdmulh v21.4s, v21.4s, v23.4s\n"
-    "sqadd v17.4s, v17.4s, v26.4s\n"
-    "add v11.4s, v11.4s, v14.4s\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "and v27.16b, v21.16b, v9.16b\n"
-    "smin v11.4s, v11.4s, v15.4s\n"
-    "srshl v17.4s, v17.4s, v9.4s\n"
-    "sqadd v18.4s, v18.4s, v8.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "add v17.4s, v17.4s, v14.4s\n"
-    "srshl v18.4s, v18.4s, v30.4s\n"
-    "sqadd v21.4s, v21.4s, v27.4s\n"
-    "smin v17.4s, v17.4s, v15.4s\n"
-    "add v18.4s, v18.4s, v14.4s\n"
-    "smax v17.4s, v17.4s, v16.4s\n"
-    "srshl v21.4s, v21.4s, v9.4s\n"
-    "smin v18.4s, v18.4s, v15.4s\n"
-    "uzp1 v11.16b, v11.16b, v17.16b\n"
-    "add v21.4s, v21.4s, v14.4s\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "str d11, [x13, x6]\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "smin v21.4s, v21.4s, v15.4s\n"
-    "smax v21.4s, v21.4s, v16.4s\n"
-    "uzp1 v18.16b, v18.16b, v21.16b\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "str d18, [x12, x6]\n"
-    "add x6, x6, #0x8\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "str d15, [x10, x14]\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "str d9, [x9, x14]\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "str d22, [x28, x14]\n"
+    "str d23, [x27, x14]\n"
+    "add x14, x14, #0x8\n"
     "beq 88f\n"
-    "add x5, x5, #0x48\n"
+    "add x17, x17, #0x48\n"
     "3:"  // Oddments
     "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
-    "tbz x3, #2, 5f\n"
-    "ld1 { v13.4s }, [x19], #0x10\n"
-    "tbz x3, #1, 4f\n"
+    "tbz x8, #2, 5f\n"
+    "ld1 { v15.4s }, [x19], #0x10\n"
+    "tbz x8, #1, 4f\n"
     "ld1 { v10.d }[0], [x19], #0x8\n"
-    "tbz x3, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v10.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
-    "tbz x3, #0, 7f\n"
+    "tbz x8, #0, 7f\n"
     "ld1 { v10.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
-    "tbz x3, #1, 6f\n"
-    "ld1 { v13.d }[0], [x19], #0x8\n"
-    "tbz x3, #0, 7f\n"
-    "ld1 { v13.s }[2], [x19]\n"
+    "tbz x8, #1, 6f\n"
+    "ld1 { v15.d }[0], [x19], #0x8\n"
+    "tbz x8, #0, 7f\n"
+    "ld1 { v15.s }[2], [x19]\n"
     "b 7f\n"
     "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 7f\n"
-    "ld1 { v13.s }[0], [x19]\n"
+    "tbz x8, #0, 7f\n"
+    "ld1 { v15.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v19.16b, v13.16b\n"
-    "ldr d0, [x5, #0x0]\n"
-    "mov v20.16b, v10.16b\n"
-    "ldr d1, [x5, #0x8]\n"
-    "mov v11.16b, v13.16b\n"
-    "ldr d2, [x5, #0x10]\n"
-    "mov v17.16b, v10.16b\n"
-    "ldr d3, [x5, #0x18]\n"
-    "mov v18.16b, v13.16b\n"
-    "ldr d4, [x5, #0x20]\n"
+    "ldr d0, [x17, #0x0]\n"
+    "ldr d1, [x17, #0x8]\n"
+    "mov v9.16b, v15.16b\n"
+    "mov v16.16b, v10.16b\n"
+    "ldr d2, [x17, #0x10]\n"
+    "ldr d3, [x17, #0x18]\n"
+    "mov v22.16b, v15.16b\n"
     "mov v21.16b, v10.16b\n"
-    "ldr d5, [x5, #0x28]\n"
-    "ssubl v0.8h, v0.8b, v12.8b\n"
-    "ldr d6, [x5, #0x30]\n"
-    "ssubl v1.8h, v1.8b, v12.8b\n"
-    "ldr d7, [x5, #0x38]\n"
-    "ssubl v2.8h, v2.8b, v12.8b\n"
-    "ldr d8, [x5, #0x40]\n"
-    "ssubl v3.8h, v3.8b, v12.8b\n"
-    "ldp x26, x25, [x7, #0x0]\n"
-    "add x26, x26, x4\n"
-    "ssubl v4.8h, v4.8b, v12.8b\n"
-    "ldp x24, x23, [x7, #0x10]\n"
-    "ssubl v5.8h, v5.8b, v12.8b\n"
-    "ldp x22, x21, [x7, #0x20]\n"
-    "ssubl v6.8h, v6.8b, v12.8b\n"
-    "add x25, x25, x4\n"
-    "ssubl v7.8h, v7.8b, v12.8b\n"
-    "ldp x20, x19, [x7, #0x30]\n"
-    "ssubl v8.8h, v8.8b, v12.8b\n"
-    "add x24, x24, x4\n"
-    "add x23, x23, x4\n"
-    "add x22, x22, x4\n"
-    "add x21, x21, x4\n"
-    "add x20, x20, x4\n"
-    "add x19, x19, x4\n"
-    "tbz x3, #2, 9f\n"
+    "ldr d4, [x17, #0x20]\n"
+    "ldr d5, [x17, #0x28]\n"
+    "mov v23.16b, v15.16b\n"
+    "mov v18.16b, v10.16b\n"
+    "ldr d6, [x17, #0x30]\n"
+    "ldr d7, [x17, #0x38]\n"
+    "ssubl v0.8h, v0.8b, v13.8b\n"
+    "ssubl v1.8h, v1.8b, v13.8b\n"
+    "ldr d8, [x17, #0x40]\n"
+    "ldp x26, x25, [x12, #0x0]\n"
+    "ssubl v2.8h, v2.8b, v13.8b\n"
+    "ssubl v3.8h, v3.8b, v13.8b\n"
+    "ldp x24, x23, [x12, #0x10]\n"
+    "ldp x22, x21, [x12, #0x20]\n"
+    "ssubl v4.8h, v4.8b, v13.8b\n"
+    "ssubl v5.8h, v5.8b, v13.8b\n"
+    "ldp x20, x19, [x12, #0x30]\n"
+    "ssubl v6.8h, v6.8b, v13.8b\n"
+    "ssubl v7.8h, v7.8b, v13.8b\n"
+    "ssubl v8.8h, v8.8b, v13.8b\n"
+    "add x26, x26, x15\n"
+    "add x25, x25, x15\n"
+    "add x24, x24, x15\n"
+    "add x23, x23, x15\n"
+    "add x22, x22, x15\n"
+    "add x21, x21, x15\n"
+    "add x20, x20, x15\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 9f\n"
     "ld1 { v31.s }[0], [x26], #0x4\n"
     "ld1 { v30.s }[0], [x25], #0x4\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
@@ -713,7 +697,7 @@
     "ld1 { v26.s }[0], [x21], #0x4\n"
     "ld1 { v25.s }[0], [x20], #0x4\n"
     "ld1 { v24.s }[0], [x19], #0x4\n"
-    "tbz x3, #1, 8f\n"
+    "tbz x8, #1, 8f\n"
     "ld1 { v31.h }[2], [x26], #0x2\n"
     "ld1 { v30.h }[2], [x25], #0x2\n"
     "ld1 { v29.h }[2], [x24], #0x2\n"
@@ -722,7 +706,7 @@
     "ld1 { v26.h }[2], [x21], #0x2\n"
     "ld1 { v25.h }[2], [x20], #0x2\n"
     "ld1 { v24.h }[2], [x19], #0x2\n"
-    "tbz x3, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[6], [x26]\n"
     "ld1 { v30.b }[6], [x25]\n"
     "ld1 { v29.b }[6], [x24]\n"
@@ -733,7 +717,7 @@
     "ld1 { v24.b }[6], [x19]\n"
     "b 11f\n"
     "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
-    "tbz x3, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[4], [x26]\n"
     "ld1 { v30.b }[4], [x25]\n"
     "ld1 { v29.b }[4], [x24]\n"
@@ -744,7 +728,7 @@
     "ld1 { v24.b }[4], [x19]\n"
     "b 11f\n"
     "9:"  // Oddments: Initial loads: Bit 2: Unset
-    "tbz x3, #1, 10f\n"
+    "tbz x8, #1, 10f\n"
     "ld1 { v31.h }[0], [x26], #0x2\n"
     "ld1 { v30.h }[0], [x25], #0x2\n"
     "ld1 { v29.h }[0], [x24], #0x2\n"
@@ -753,7 +737,7 @@
     "ld1 { v26.h }[0], [x21], #0x2\n"
     "ld1 { v25.h }[0], [x20], #0x2\n"
     "ld1 { v24.h }[0], [x19], #0x2\n"
-    "tbz x3, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[2], [x26]\n"
     "ld1 { v30.b }[2], [x25]\n"
     "ld1 { v29.b }[2], [x24]\n"
@@ -764,7 +748,7 @@
     "ld1 { v24.b }[2], [x19]\n"
     "b 11f\n"
     "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 11f\n"
+    "tbz x8, #0, 11f\n"
     "ld1 { v31.b }[0], [x26]\n"
     "ld1 { v30.b }[0], [x25]\n"
     "ld1 { v29.b }[0], [x24]\n"
@@ -774,646 +758,636 @@
     "ld1 { v25.b }[0], [x20]\n"
     "ld1 { v24.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
-    "usubl v31.8h, v31.8b, v22.8b\n"
-    "ldr x22, [x7, #0x40]\n"
-    "add x22, x22, x4\n"
-    "usubl v30.8h, v30.8b, v22.8b\n"
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "usubl v24.8h, v24.8b, v22.8b\n"
-    "smlal v13.4s, v31.4h, v8.4h\n"
+    "usubl v31.8h, v31.8b, v12.8b\n"
+    "smlal v15.4s, v31.4h, v8.4h\n"
     "smlal2 v10.4s, v31.8h, v8.8h\n"
-    "smlal v19.4s, v31.4h, v6.4h\n"
-    "smlal2 v20.4s, v31.8h, v6.8h\n"
-    "smlal v11.4s, v31.4h, v2.4h\n"
-    "smlal2 v17.4s, v31.8h, v2.8h\n"
-    "smlal v18.4s, v31.4h, v0.4h\n"
-    "smlal2 v21.4s, v31.8h, v0.8h\n"
-    "smlal v13.4s, v30.4h, v0.4h\n"
+    "ldr x24, [x12, #0x40]\n"
+    "usubl v30.8h, v30.8b, v12.8b\n"
+    "smlal v15.4s, v30.4h, v0.4h\n"
     "smlal2 v10.4s, v30.8h, v0.8h\n"
-    "smlal v19.4s, v28.4h, v1.4h\n"
-    "smlal2 v20.4s, v28.8h, v1.8h\n"
-    "smlal v13.4s, v29.4h, v1.4h\n"
+    "add x24, x24, x15\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v9.4s, v31.4h, v6.4h\n"
+    "smlal2 v16.4s, v31.8h, v6.8h\n"
+    "smlal v15.4s, v29.4h, v1.4h\n"
     "smlal2 v10.4s, v29.8h, v1.8h\n"
-    "smlal v19.4s, v27.4h, v2.4h\n"
-    "smlal2 v20.4s, v27.8h, v2.8h\n"
-    "smlal v13.4s, v26.4h, v3.4h\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "smlal v9.4s, v28.4h, v1.4h\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v3.4h\n"
     "smlal2 v10.4s, v26.8h, v3.8h\n"
-    "smlal v19.4s, v24.4h, v0.4h\n"
-    "smlal2 v20.4s, v24.8h, v0.8h\n"
-    "smlal v13.4s, v25.4h, v4.4h\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "smlal v9.4s, v27.4h, v2.4h\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v4.4h\n"
     "smlal2 v10.4s, v25.8h, v4.8h\n"
-    "smlal v13.4s, v24.4h, v2.4h\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "smlal v22.4s, v31.4h, v2.4h\n"
+    "smlal2 v21.4s, v31.8h, v2.8h\n"
+    "smlal v23.4s, v31.4h, v0.4h\n"
+    "smlal2 v18.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
     "smlal2 v10.4s, v24.8h, v2.8h\n"
-    "tbz x3, #2, 13f\n"
-    "ld1 { v29.s }[0], [x22], #0x4\n"
-    "tbz x3, #1, 12f\n"
-    "ld1 { v29.h }[2], [x22], #0x2\n"
-    "tbz x3, #0, 15f\n"
-    "ld1 { v29.b }[6], [x22]\n"
+    "smlal v9.4s, v24.4h, v0.4h\n"
+    "smlal2 v16.4s, v24.8h, v0.8h\n"
+    "tbz x8, #2, 13f\n"
+    "ld1 { v29.s }[0], [x24], #0x4\n"
+    "tbz x8, #1, 12f\n"
+    "ld1 { v29.h }[2], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[6], [x24]\n"
     "b 15f\n"
     "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 15f\n"
-    "ld1 { v29.b }[4], [x22]\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[4], [x24]\n"
     "b 15f\n"
     "13:"  // Oddments: Load (1, 3): Bit 2: Unset
-    "tbz x3, #1, 14f\n"
-    "ld1 { v29.h }[0], [x22], #0x2\n"
-    "tbz x3, #0, 15f\n"
-    "ld1 { v29.b }[2], [x22]\n"
+    "tbz x8, #1, 14f\n"
+    "ld1 { v29.h }[0], [x24], #0x2\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[2], [x24]\n"
     "b 15f\n"
     "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 15f\n"
-    "ld1 { v29.b }[0], [x22]\n"
+    "tbz x8, #0, 15f\n"
+    "ld1 { v29.b }[0], [x24]\n"
     "15:"  // Oddments: Load (1, 3): Bit 2: End
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "ldr x21, [x7, #0x48]\n"
-    "smlal v19.4s, v29.4h, v4.4h\n"
-    "add x21, x21, x4\n"
-    "smlal2 v20.4s, v29.8h, v4.8h\n"
-    "tbz x3, #2, 17f\n"
-    "ld1 { v28.s }[0], [x21], #0x4\n"
-    "tbz x3, #1, 16f\n"
-    "ld1 { v28.h }[2], [x21], #0x2\n"
-    "tbz x3, #0, 19f\n"
-    "ld1 { v28.b }[6], [x21]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "ldr x23, [x12, #0x48]\n"
+    "smlal v9.4s, v29.4h, v4.4h\n"
+    "smlal2 v16.4s, v29.8h, v4.8h\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 17f\n"
+    "ld1 { v28.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 16f\n"
+    "ld1 { v28.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[6], [x23]\n"
     "b 19f\n"
     "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 19f\n"
-    "ld1 { v28.b }[4], [x21]\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[4], [x23]\n"
     "b 19f\n"
     "17:"  // Oddments: Load (1, 4): Bit 2: Unset
-    "tbz x3, #1, 18f\n"
-    "ld1 { v28.h }[0], [x21], #0x2\n"
-    "tbz x3, #0, 19f\n"
-    "ld1 { v28.b }[2], [x21]\n"
+    "tbz x8, #1, 18f\n"
+    "ld1 { v28.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[2], [x23]\n"
     "b 19f\n"
     "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 19f\n"
-    "ld1 { v28.b }[0], [x21]\n"
+    "tbz x8, #0, 19f\n"
+    "ld1 { v28.b }[0], [x23]\n"
     "19:"  // Oddments: Load (1, 4): Bit 2: End
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "ldr x20, [x7, #0x50]\n"
-    "smlal v19.4s, v28.4h, v5.4h\n"
-    "add x20, x20, x4\n"
-    "smlal2 v20.4s, v28.8h, v5.8h\n"
-    "tbz x3, #2, 21f\n"
-    "ld1 { v27.s }[0], [x20], #0x4\n"
-    "tbz x3, #1, 20f\n"
-    "ld1 { v27.h }[2], [x20], #0x2\n"
-    "tbz x3, #0, 23f\n"
-    "ld1 { v27.b }[6], [x20]\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "ldr x21, [x12, #0x50]\n"
+    "smlal v9.4s, v28.4h, v5.4h\n"
+    "smlal2 v16.4s, v28.8h, v5.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 21f\n"
+    "ld1 { v27.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 20f\n"
+    "ld1 { v27.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[6], [x21]\n"
     "b 23f\n"
     "20:"  // Oddments: Load (1, 2): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 23f\n"
-    "ld1 { v27.b }[4], [x20]\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[4], [x21]\n"
     "b 23f\n"
     "21:"  // Oddments: Load (1, 2): Bit 2: Unset
-    "tbz x3, #1, 22f\n"
-    "ld1 { v27.h }[0], [x20], #0x2\n"
-    "tbz x3, #0, 23f\n"
-    "ld1 { v27.b }[2], [x20]\n"
+    "tbz x8, #1, 22f\n"
+    "ld1 { v27.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[2], [x21]\n"
     "b 23f\n"
     "22:"  // Oddments: Load (1, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 23f\n"
-    "ld1 { v27.b }[0], [x20]\n"
+    "tbz x8, #0, 23f\n"
+    "ld1 { v27.b }[0], [x21]\n"
     "23:"  // Oddments: Load (1, 2): Bit 2: End
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "ldr x19, [x7, #0x58]\n"
-    "smlal v13.4s, v27.4h, v5.4h\n"
-    "add x19, x19, x4\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "ldr x19, [x12, #0x58]\n"
+    "smlal v15.4s, v27.4h, v5.4h\n"
     "smlal2 v10.4s, v27.8h, v5.8h\n"
-    "smlal v19.4s, v27.4h, v3.4h\n"
-    "smlal2 v20.4s, v27.8h, v3.8h\n"
-    "tbz x3, #2, 25f\n"
+    "smlal v9.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 25f\n"
     "ld1 { v26.s }[0], [x19], #0x4\n"
-    "tbz x3, #1, 24f\n"
+    "tbz x8, #1, 24f\n"
     "ld1 { v26.h }[2], [x19], #0x2\n"
-    "tbz x3, #0, 27f\n"
+    "tbz x8, #0, 27f\n"
     "ld1 { v26.b }[6], [x19]\n"
     "b 27f\n"
     "24:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 27f\n"
+    "tbz x8, #0, 27f\n"
     "ld1 { v26.b }[4], [x19]\n"
     "b 27f\n"
     "25:"  // Oddments: Load (3, 0): Bit 2: Unset
-    "tbz x3, #1, 26f\n"
+    "tbz x8, #1, 26f\n"
     "ld1 { v26.h }[0], [x19], #0x2\n"
-    "tbz x3, #0, 27f\n"
+    "tbz x8, #0, 27f\n"
     "ld1 { v26.b }[2], [x19]\n"
     "b 27f\n"
     "26:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 27f\n"
+    "tbz x8, #0, 27f\n"
     "ld1 { v26.b }[0], [x19]\n"
     "27:"  // Oddments: Load (3, 0): Bit 2: End
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "ldr x11, [x7, #0x60]\n"
-    "smlal v11.4s, v26.4h, v3.4h\n"
-    "add x11, x11, x4\n"
-    "smlal2 v17.4s, v26.8h, v3.8h\n"
-    "tbz x3, #2, 29f\n"
-    "ld1 { v25.s }[0], [x11], #0x4\n"
-    "tbz x3, #1, 28f\n"
-    "ld1 { v25.h }[2], [x11], #0x2\n"
-    "tbz x3, #0, 31f\n"
-    "ld1 { v25.b }[6], [x11]\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "ldr x20, [x12, #0x60]\n"
+    "smlal v22.4s, v26.4h, v3.4h\n"
+    "smlal2 v21.4s, v26.8h, v3.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 29f\n"
+    "ld1 { v25.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 28f\n"
+    "ld1 { v25.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[6], [x20]\n"
     "b 31f\n"
     "28:"  // Oddments: Load (2, 0): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 31f\n"
-    "ld1 { v25.b }[4], [x11]\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[4], [x20]\n"
     "b 31f\n"
     "29:"  // Oddments: Load (2, 0): Bit 2: Unset
-    "tbz x3, #1, 30f\n"
-    "ld1 { v25.h }[0], [x11], #0x2\n"
-    "tbz x3, #0, 31f\n"
-    "ld1 { v25.b }[2], [x11]\n"
+    "tbz x8, #1, 30f\n"
+    "ld1 { v25.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[2], [x20]\n"
     "b 31f\n"
     "30:"  // Oddments: Load (2, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 31f\n"
-    "ld1 { v25.b }[0], [x11]\n"
+    "tbz x8, #0, 31f\n"
+    "ld1 { v25.b }[0], [x20]\n"
     "31:"  // Oddments: Load (2, 0): Bit 2: End
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "ldr x10, [x7, #0x68]\n"
-    "smlal v13.4s, v25.4h, v6.4h\n"
-    "add x10, x10, x4\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "ldr x19, [x12, #0x68]\n"
+    "smlal v15.4s, v25.4h, v6.4h\n"
     "smlal2 v10.4s, v25.8h, v6.8h\n"
-    "smlal v11.4s, v25.4h, v0.4h\n"
-    "smlal2 v17.4s, v25.8h, v0.8h\n"
-    "tbz x3, #2, 33f\n"
-    "ld1 { v29.s }[0], [x10], #0x4\n"
-    "tbz x3, #1, 32f\n"
-    "ld1 { v29.h }[2], [x10], #0x2\n"
-    "tbz x3, #0, 35f\n"
-    "ld1 { v29.b }[6], [x10]\n"
+    "smlal v22.4s, v25.4h, v0.4h\n"
+    "smlal2 v21.4s, v25.8h, v0.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 33f\n"
+    "ld1 { v29.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 32f\n"
+    "ld1 { v29.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[6], [x19]\n"
     "b 35f\n"
     "32:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 35f\n"
-    "ld1 { v29.b }[4], [x10]\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[4], [x19]\n"
     "b 35f\n"
     "33:"  // Oddments: Load (3, 1): Bit 2: Unset
-    "tbz x3, #1, 34f\n"
-    "ld1 { v29.h }[0], [x10], #0x2\n"
-    "tbz x3, #0, 35f\n"
-    "ld1 { v29.b }[2], [x10]\n"
+    "tbz x8, #1, 34f\n"
+    "ld1 { v29.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[2], [x19]\n"
     "b 35f\n"
     "34:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 35f\n"
-    "ld1 { v29.b }[0], [x10]\n"
+    "tbz x8, #0, 35f\n"
+    "ld1 { v29.b }[0], [x19]\n"
     "35:"  // Oddments: Load (3, 1): Bit 2: End
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "ldr x9, [x7, #0x70]\n"
-    "smlal v11.4s, v29.4h, v4.4h\n"
-    "add x9, x9, x4\n"
-    "smlal2 v17.4s, v29.8h, v4.8h\n"
-    "tbz x3, #2, 37f\n"
-    "ld1 { v24.s }[0], [x9], #0x4\n"
-    "tbz x3, #1, 36f\n"
-    "ld1 { v24.h }[2], [x9], #0x2\n"
-    "tbz x3, #0, 39f\n"
-    "ld1 { v24.b }[6], [x9]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "ldr x19, [x12, #0x70]\n"
+    "smlal v22.4s, v29.4h, v4.4h\n"
+    "smlal2 v21.4s, v29.8h, v4.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 37f\n"
+    "ld1 { v24.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 36f\n"
+    "ld1 { v24.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[6], [x19]\n"
     "b 39f\n"
     "36:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 39f\n"
-    "ld1 { v24.b }[4], [x9]\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[4], [x19]\n"
     "b 39f\n"
     "37:"  // Oddments: Load (2, 1): Bit 2: Unset
-    "tbz x3, #1, 38f\n"
-    "ld1 { v24.h }[0], [x9], #0x2\n"
-    "tbz x3, #0, 39f\n"
-    "ld1 { v24.b }[2], [x9]\n"
+    "tbz x8, #1, 38f\n"
+    "ld1 { v24.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[2], [x19]\n"
     "b 39f\n"
     "38:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 39f\n"
-    "ld1 { v24.b }[0], [x9]\n"
+    "tbz x8, #0, 39f\n"
+    "ld1 { v24.b }[0], [x19]\n"
     "39:"  // Oddments: Load (2, 1): Bit 2: End
-    "usubl v24.8h, v24.8b, v22.8b\n"
-    "ldr x28, [x7, #0x78]\n"
-    "smlal v13.4s, v24.4h, v7.4h\n"
-    "add x28, x28, x4\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "ldr x22, [x12, #0x78]\n"
+    "smlal v15.4s, v24.4h, v7.4h\n"
     "smlal2 v10.4s, v24.8h, v7.8h\n"
-    "smlal v11.4s, v24.4h, v1.4h\n"
-    "smlal2 v17.4s, v24.8h, v1.8h\n"
-    "tbz x3, #2, 41f\n"
-    "ld1 { v27.s }[0], [x28], #0x4\n"
-    "tbz x3, #1, 40f\n"
-    "ld1 { v27.h }[2], [x28], #0x2\n"
-    "tbz x3, #0, 43f\n"
-    "ld1 { v27.b }[6], [x28]\n"
+    "smlal v22.4s, v24.4h, v1.4h\n"
+    "smlal2 v21.4s, v24.8h, v1.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 41f\n"
+    "ld1 { v27.s }[0], [x22], #0x4\n"
+    "tbz x8, #1, 40f\n"
+    "ld1 { v27.h }[2], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[6], [x22]\n"
     "b 43f\n"
     "40:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 43f\n"
-    "ld1 { v27.b }[4], [x28]\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[4], [x22]\n"
     "b 43f\n"
     "41:"  // Oddments: Load (3, 3): Bit 2: Unset
-    "tbz x3, #1, 42f\n"
-    "ld1 { v27.h }[0], [x28], #0x2\n"
-    "tbz x3, #0, 43f\n"
-    "ld1 { v27.b }[2], [x28]\n"
+    "tbz x8, #1, 42f\n"
+    "ld1 { v27.h }[0], [x22], #0x2\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[2], [x22]\n"
     "b 43f\n"
     "42:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 43f\n"
-    "ld1 { v27.b }[0], [x28]\n"
+    "tbz x8, #0, 43f\n"
+    "ld1 { v27.b }[0], [x22]\n"
     "43:"  // Oddments: Load (3, 3): Bit 2: End
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "ldr x27, [x7, #0x80]\n"
-    "smlal v18.4s, v27.4h, v4.4h\n"
-    "add x27, x27, x4\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "tbz x3, #2, 45f\n"
-    "ld1 { v28.s }[0], [x27], #0x4\n"
-    "tbz x3, #1, 44f\n"
-    "ld1 { v28.h }[2], [x27], #0x2\n"
-    "tbz x3, #0, 47f\n"
-    "ld1 { v28.b }[6], [x27]\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "ldr x21, [x12, #0x80]\n"
+    "smlal v23.4s, v27.4h, v4.4h\n"
+    "smlal2 v18.4s, v27.8h, v4.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 45f\n"
+    "ld1 { v28.s }[0], [x21], #0x4\n"
+    "tbz x8, #1, 44f\n"
+    "ld1 { v28.h }[2], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[6], [x21]\n"
     "b 47f\n"
     "44:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 47f\n"
-    "ld1 { v28.b }[4], [x27]\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[4], [x21]\n"
     "b 47f\n"
     "45:"  // Oddments: Load (2, 3): Bit 2: Unset
-    "tbz x3, #1, 46f\n"
-    "ld1 { v28.h }[0], [x27], #0x2\n"
-    "tbz x3, #0, 47f\n"
-    "ld1 { v28.b }[2], [x27]\n"
+    "tbz x8, #1, 46f\n"
+    "ld1 { v28.h }[0], [x21], #0x2\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[2], [x21]\n"
     "b 47f\n"
     "46:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 47f\n"
-    "ld1 { v28.b }[0], [x27]\n"
+    "tbz x8, #0, 47f\n"
+    "ld1 { v28.b }[0], [x21]\n"
     "47:"  // Oddments: Load (2, 3): Bit 2: End
-    "usubl v28.8h, v28.8b, v22.8b\n"
-    "ldr x26, [x7, #0x88]\n"
-    "smlal v19.4s, v28.4h, v7.4h\n"
-    "add x26, x26, x4\n"
-    "smlal2 v20.4s, v28.8h, v7.8h\n"
-    "smlal v18.4s, v28.4h, v1.4h\n"
-    "smlal2 v21.4s, v28.8h, v1.8h\n"
-    "tbz x3, #2, 49f\n"
-    "ld1 { v26.s }[0], [x26], #0x4\n"
-    "tbz x3, #1, 48f\n"
-    "ld1 { v26.h }[2], [x26], #0x2\n"
-    "tbz x3, #0, 51f\n"
-    "ld1 { v26.b }[6], [x26]\n"
+    "usubl v28.8h, v28.8b, v12.8b\n"
+    "ldr x20, [x12, #0x88]\n"
+    "smlal v9.4s, v28.4h, v7.4h\n"
+    "smlal2 v16.4s, v28.8h, v7.8h\n"
+    "smlal v23.4s, v28.4h, v1.4h\n"
+    "smlal2 v18.4s, v28.8h, v1.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 49f\n"
+    "ld1 { v26.s }[0], [x20], #0x4\n"
+    "tbz x8, #1, 48f\n"
+    "ld1 { v26.h }[2], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[6], [x20]\n"
     "b 51f\n"
     "48:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 51f\n"
-    "ld1 { v26.b }[4], [x26]\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[4], [x20]\n"
     "b 51f\n"
     "49:"  // Oddments: Load (3, 4): Bit 2: Unset
-    "tbz x3, #1, 50f\n"
-    "ld1 { v26.h }[0], [x26], #0x2\n"
-    "tbz x3, #0, 51f\n"
-    "ld1 { v26.b }[2], [x26]\n"
+    "tbz x8, #1, 50f\n"
+    "ld1 { v26.h }[0], [x20], #0x2\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[2], [x20]\n"
     "b 51f\n"
     "50:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 51f\n"
-    "ld1 { v26.b }[0], [x26]\n"
+    "tbz x8, #0, 51f\n"
+    "ld1 { v26.b }[0], [x20]\n"
     "51:"  // Oddments: Load (3, 4): Bit 2: End
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "ldr x25, [x7, #0x90]\n"
-    "smlal v18.4s, v26.4h, v5.4h\n"
-    "add x25, x25, x4\n"
-    "smlal2 v21.4s, v26.8h, v5.8h\n"
-    "tbz x3, #2, 53f\n"
-    "ld1 { v25.s }[0], [x25], #0x4\n"
-    "tbz x3, #1, 52f\n"
-    "ld1 { v25.h }[2], [x25], #0x2\n"
-    "tbz x3, #0, 55f\n"
-    "ld1 { v25.b }[6], [x25]\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "ldr x23, [x12, #0x90]\n"
+    "smlal v23.4s, v26.4h, v5.4h\n"
+    "smlal2 v18.4s, v26.8h, v5.8h\n"
+    "add x23, x23, x15\n"
+    "tbz x8, #2, 53f\n"
+    "ld1 { v25.s }[0], [x23], #0x4\n"
+    "tbz x8, #1, 52f\n"
+    "ld1 { v25.h }[2], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[6], [x23]\n"
     "b 55f\n"
     "52:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 55f\n"
-    "ld1 { v25.b }[4], [x25]\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[4], [x23]\n"
     "b 55f\n"
     "53:"  // Oddments: Load (4, 0): Bit 2: Unset
-    "tbz x3, #1, 54f\n"
-    "ld1 { v25.h }[0], [x25], #0x2\n"
-    "tbz x3, #0, 55f\n"
-    "ld1 { v25.b }[2], [x25]\n"
+    "tbz x8, #1, 54f\n"
+    "ld1 { v25.h }[0], [x23], #0x2\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[2], [x23]\n"
     "b 55f\n"
     "54:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 55f\n"
-    "ld1 { v25.b }[0], [x25]\n"
+    "tbz x8, #0, 55f\n"
+    "ld1 { v25.b }[0], [x23]\n"
     "55:"  // Oddments: Load (4, 0): Bit 2: End
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "ldr x24, [x7, #0x98]\n"
-    "smlal v11.4s, v25.4h, v6.4h\n"
-    "add x24, x24, x4\n"
-    "smlal2 v17.4s, v25.8h, v6.8h\n"
-    "tbz x3, #2, 57f\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "ldr x24, [x12, #0x98]\n"
+    "smlal v22.4s, v25.4h, v6.4h\n"
+    "smlal2 v21.4s, v25.8h, v6.8h\n"
+    "add x24, x24, x15\n"
+    "tbz x8, #2, 57f\n"
     "ld1 { v29.s }[0], [x24], #0x4\n"
-    "tbz x3, #1, 56f\n"
+    "tbz x8, #1, 56f\n"
     "ld1 { v29.h }[2], [x24], #0x2\n"
-    "tbz x3, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[6], [x24]\n"
     "b 59f\n"
     "56:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[4], [x24]\n"
     "b 59f\n"
     "57:"  // Oddments: Load (2, 4): Bit 2: Unset
-    "tbz x3, #1, 58f\n"
+    "tbz x8, #1, 58f\n"
     "ld1 { v29.h }[0], [x24], #0x2\n"
-    "tbz x3, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[2], [x24]\n"
     "b 59f\n"
     "58:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 59f\n"
+    "tbz x8, #0, 59f\n"
     "ld1 { v29.b }[0], [x24]\n"
     "59:"  // Oddments: Load (2, 4): Bit 2: End
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "ldr x23, [x7, #0xa0]\n"
-    "smlal v19.4s, v29.4h, v8.4h\n"
-    "add x23, x23, x4\n"
-    "smlal2 v20.4s, v29.8h, v8.8h\n"
-    "smlal v18.4s, v29.4h, v2.4h\n"
-    "smlal2 v21.4s, v29.8h, v2.8h\n"
-    "tbz x3, #2, 61f\n"
-    "ld1 { v27.s }[0], [x23], #0x4\n"
-    "tbz x3, #1, 60f\n"
-    "ld1 { v27.h }[2], [x23], #0x2\n"
-    "tbz x3, #0, 63f\n"
-    "ld1 { v27.b }[6], [x23]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "ldr x19, [x12, #0xa0]\n"
+    "smlal v9.4s, v29.4h, v8.4h\n"
+    "smlal2 v16.4s, v29.8h, v8.8h\n"
+    "smlal v23.4s, v29.4h, v2.4h\n"
+    "smlal2 v18.4s, v29.8h, v2.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 61f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x8, #1, 60f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 63f\n"
     "60:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 63f\n"
-    "ld1 { v27.b }[4], [x23]\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 63f\n"
     "61:"  // Oddments: Load (4, 1): Bit 2: Unset
-    "tbz x3, #1, 62f\n"
-    "ld1 { v27.h }[0], [x23], #0x2\n"
-    "tbz x3, #0, 63f\n"
-    "ld1 { v27.b }[2], [x23]\n"
+    "tbz x8, #1, 62f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 63f\n"
     "62:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 63f\n"
-    "ld1 { v27.b }[0], [x23]\n"
+    "tbz x8, #0, 63f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "63:"  // Oddments: Load (4, 1): Bit 2: End
-    "usubl v27.8h, v27.8b, v22.8b\n"
-    "ldr x22, [x7, #0xa8]\n"
-    "smlal v11.4s, v27.4h, v7.4h\n"
-    "add x22, x22, x4\n"
-    "smlal2 v17.4s, v27.8h, v7.8h\n"
-    "tbz x3, #2, 65f\n"
+    "usubl v27.8h, v27.8b, v12.8b\n"
+    "ldr x22, [x12, #0xa8]\n"
+    "smlal v22.4s, v27.4h, v7.4h\n"
+    "smlal2 v21.4s, v27.8h, v7.8h\n"
+    "add x22, x22, x15\n"
+    "tbz x8, #2, 65f\n"
     "ld1 { v24.s }[0], [x22], #0x4\n"
-    "tbz x3, #1, 64f\n"
+    "tbz x8, #1, 64f\n"
     "ld1 { v24.h }[2], [x22], #0x2\n"
-    "tbz x3, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[6], [x22]\n"
     "b 67f\n"
     "64:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[4], [x22]\n"
     "b 67f\n"
     "65:"  // Oddments: Load (3, 2): Bit 2: Unset
-    "tbz x3, #1, 66f\n"
+    "tbz x8, #1, 66f\n"
     "ld1 { v24.h }[0], [x22], #0x2\n"
-    "tbz x3, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[2], [x22]\n"
     "b 67f\n"
     "66:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 67f\n"
+    "tbz x8, #0, 67f\n"
     "ld1 { v24.b }[0], [x22]\n"
     "67:"  // Oddments: Load (3, 2): Bit 2: End
-    "usubl v24.8h, v24.8b, v22.8b\n"
-    "ldr x21, [x7, #0xb0]\n"
-    "smlal v11.4s, v24.4h, v5.4h\n"
-    "add x21, x21, x4\n"
-    "smlal2 v17.4s, v24.8h, v5.8h\n"
-    "smlal v18.4s, v24.4h, v3.4h\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "tbz x3, #2, 69f\n"
+    "usubl v24.8h, v24.8b, v12.8b\n"
+    "ldr x21, [x12, #0xb0]\n"
+    "smlal v22.4s, v24.4h, v5.4h\n"
+    "smlal2 v21.4s, v24.8h, v5.8h\n"
+    "smlal v23.4s, v24.4h, v3.4h\n"
+    "smlal2 v18.4s, v24.8h, v3.8h\n"
+    "add x21, x21, x15\n"
+    "tbz x8, #2, 69f\n"
     "ld1 { v26.s }[0], [x21], #0x4\n"
-    "tbz x3, #1, 68f\n"
+    "tbz x8, #1, 68f\n"
     "ld1 { v26.h }[2], [x21], #0x2\n"
-    "tbz x3, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[6], [x21]\n"
     "b 71f\n"
     "68:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[4], [x21]\n"
     "b 71f\n"
     "69:"  // Oddments: Load (4, 3): Bit 2: Unset
-    "tbz x3, #1, 70f\n"
+    "tbz x8, #1, 70f\n"
     "ld1 { v26.h }[0], [x21], #0x2\n"
-    "tbz x3, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[2], [x21]\n"
     "b 71f\n"
     "70:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 71f\n"
+    "tbz x8, #0, 71f\n"
     "ld1 { v26.b }[0], [x21]\n"
     "71:"  // Oddments: Load (4, 3): Bit 2: End
-    "usubl v26.8h, v26.8b, v22.8b\n"
-    "ldr x20, [x7, #0xb8]\n"
-    "smlal v18.4s, v26.4h, v7.4h\n"
-    "add x20, x20, x4\n"
-    "smlal2 v21.4s, v26.8h, v7.8h\n"
-    "tbz x3, #2, 73f\n"
+    "usubl v26.8h, v26.8b, v12.8b\n"
+    "ldr x20, [x12, #0xb8]\n"
+    "smlal v23.4s, v26.4h, v7.4h\n"
+    "smlal2 v18.4s, v26.8h, v7.8h\n"
+    "add x20, x20, x15\n"
+    "tbz x8, #2, 73f\n"
     "ld1 { v25.s }[0], [x20], #0x4\n"
-    "tbz x3, #1, 72f\n"
+    "tbz x8, #1, 72f\n"
     "ld1 { v25.h }[2], [x20], #0x2\n"
-    "tbz x3, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[6], [x20]\n"
     "b 75f\n"
     "72:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[4], [x20]\n"
     "b 75f\n"
     "73:"  // Oddments: Load (4, 2): Bit 2: Unset
-    "tbz x3, #1, 74f\n"
+    "tbz x8, #1, 74f\n"
     "ld1 { v25.h }[0], [x20], #0x2\n"
-    "tbz x3, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[2], [x20]\n"
     "b 75f\n"
     "74:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 75f\n"
+    "tbz x8, #0, 75f\n"
     "ld1 { v25.b }[0], [x20]\n"
     "75:"  // Oddments: Load (4, 2): Bit 2: End
-    "usubl v25.8h, v25.8b, v22.8b\n"
-    "ldr x19, [x7, #0xc0]\n"
-    "smlal v11.4s, v25.4h, v8.4h\n"
-    "add x19, x19, x4\n"
-    "smlal2 v17.4s, v25.8h, v8.8h\n"
-    "smlal v18.4s, v25.4h, v6.4h\n"
-    "smlal2 v21.4s, v25.8h, v6.8h\n"
-    "tbz x3, #2, 77f\n"
+    "usubl v25.8h, v25.8b, v12.8b\n"
+    "ldr x19, [x12, #0xc0]\n"
+    "smlal v22.4s, v25.4h, v8.4h\n"
+    "smlal2 v21.4s, v25.8h, v8.8h\n"
+    "smlal v23.4s, v25.4h, v6.4h\n"
+    "smlal2 v18.4s, v25.8h, v6.8h\n"
+    "add x19, x19, x15\n"
+    "tbz x8, #2, 77f\n"
     "ld1 { v29.s }[0], [x19], #0x4\n"
-    "tbz x3, #1, 76f\n"
+    "tbz x8, #1, 76f\n"
     "ld1 { v29.h }[2], [x19], #0x2\n"
-    "tbz x3, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[6], [x19]\n"
     "b 79f\n"
     "76:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
-    "tbz x3, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[4], [x19]\n"
     "b 79f\n"
     "77:"  // Oddments: Load (4, 4): Bit 2: Unset
-    "tbz x3, #1, 78f\n"
+    "tbz x8, #1, 78f\n"
     "ld1 { v29.h }[0], [x19], #0x2\n"
-    "tbz x3, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[2], [x19]\n"
     "b 79f\n"
     "78:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 79f\n"
+    "tbz x8, #0, 79f\n"
     "ld1 { v29.b }[0], [x19]\n"
     "79:"  // Oddments: Load (4, 4): Bit 2: End
-    "usubl v29.8h, v29.8b, v22.8b\n"
-    "smlal v18.4s, v29.4h, v8.4h\n"
-    "smlal2 v21.4s, v29.8h, v8.8h\n"
-    "tbz x3, #2, 81f\n"
-    "ld1 { v31.4s }, [x8], #0x10\n"
-    "ld1 { v30.4s }, [x16], #0x10\n"
-    "tbz x3, #1, 80f\n"
-    "ld1 { v23.d }[0], [x8], #0x8\n"
-    "ld1 { v9.d }[0], [x16], #0x8\n"
-    "tbz x3, #0, 83f\n"
-    "ld1 { v23.s }[2], [x8]\n"
-    "ld1 { v9.s }[2], [x16]\n"
+    "usubl v29.8h, v29.8b, v12.8b\n"
+    "smlal v23.4s, v29.4h, v8.4h\n"
+    "smlal2 v18.4s, v29.8h, v8.8h\n"
+    "tbz x8, #2, 81f\n"
+    "ld1 { v19.4s }, [x13], #0x10\n"
+    "ld1 { v0.4s }, [x11], #0x10\n"
+    "tbz x8, #1, 80f\n"
+    "ld1 { v4.d }[0], [x13], #0x8\n"
+    "ld1 { v31.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v4.s }[2], [x13]\n"
+    "ld1 { v31.s }[2], [x11]\n"
     "b 83f\n"
     "80:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
-    "tbz x3, #0, 83f\n"
-    "ld1 { v23.s }[0], [x8]\n"
-    "ld1 { v9.s }[0], [x16]\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v4.s }[0], [x13]\n"
+    "ld1 { v31.s }[0], [x11]\n"
     "b 83f\n"
     "81:"  // Oddments: Load requant params: Bit 2: Unset
-    "tbz x3, #1, 82f\n"
-    "ld1 { v31.d }[0], [x8], #0x8\n"
-    "ld1 { v30.d }[0], [x16], #0x8\n"
-    "tbz x3, #0, 83f\n"
-    "ld1 { v31.s }[2], [x8]\n"
-    "ld1 { v30.s }[2], [x16]\n"
+    "tbz x8, #1, 82f\n"
+    "ld1 { v19.d }[0], [x13], #0x8\n"
+    "ld1 { v0.d }[0], [x11], #0x8\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[2], [x13]\n"
+    "ld1 { v0.s }[2], [x11]\n"
     "b 83f\n"
     "82:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 83f\n"
-    "ld1 { v31.s }[0], [x8]\n"
-    "ld1 { v30.s }[0], [x16]\n"
+    "tbz x8, #0, 83f\n"
+    "ld1 { v19.s }[0], [x13]\n"
+    "ld1 { v0.s }[0], [x11]\n"
     "83:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v13.4s, v13.4s, v31.4s\n"
-    "add x15, x15, x6\n"
-    "sqrdmulh v10.4s, v10.4s, v23.4s\n"
-    "add x14, x14, x6\n"
-    "sqrdmulh v19.4s, v19.4s, v31.4s\n"
-    "add x13, x13, x6\n"
-    "sqrdmulh v20.4s, v20.4s, v23.4s\n"
-    "add x12, x12, x6\n"
-    "sqrdmulh v11.4s, v11.4s, v31.4s\n"
-    "and v27.16b, v13.16b, v30.16b\n"
-    "and v7.16b, v10.16b, v9.16b\n"
-    "and v6.16b, v19.16b, v30.16b\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "sshr v7.4s, v7.4s, #0x1f\n"
-    "sshr v6.4s, v6.4s, #0x1f\n"
-    "sqadd v13.4s, v13.4s, v27.4s\n"
-    "sqadd v10.4s, v10.4s, v7.4s\n"
-    "sqadd v19.4s, v19.4s, v6.4s\n"
-    "and v3.16b, v20.16b, v9.16b\n"
-    "srshl v13.4s, v13.4s, v30.4s\n"
-    "srshl v10.4s, v10.4s, v9.4s\n"
-    "srshl v19.4s, v19.4s, v30.4s\n"
-    "sshr v3.4s, v3.4s, #0x1f\n"
-    "add v13.4s, v13.4s, v14.4s\n"
-    "add v10.4s, v10.4s, v14.4s\n"
-    "add v19.4s, v19.4s, v14.4s\n"
-    "smin v13.4s, v13.4s, v15.4s\n"
-    "smin v10.4s, v10.4s, v15.4s\n"
-    "smin v19.4s, v19.4s, v15.4s\n"
-    "smax v13.4s, v13.4s, v16.4s\n"
-    "smax v10.4s, v10.4s, v16.4s\n"
-    "smax v19.4s, v19.4s, v16.4s\n"
-    "sqadd v20.4s, v20.4s, v3.4s\n"
-    "uzp1 v13.16b, v13.16b, v10.16b\n"
-    "and v28.16b, v11.16b, v30.16b\n"
-    "uzp1 v13.16b, v13.16b, v13.16b\n"
-    "srshl v20.4s, v20.4s, v9.4s\n"
+    "sqdmulh v15.4s, v15.4s, v19.4s\n"
+    "sqdmulh v9.4s, v9.4s, v19.4s\n"
+    "add x10, x10, x14\n"
+    "add x9, x9, x14\n"
+    "sqdmulh v22.4s, v22.4s, v19.4s\n"
+    "sqdmulh v23.4s, v23.4s, v19.4s\n"
+    "add x28, x28, x14\n"
+    "add x27, x27, x14\n"
+    "and v30.16b, v15.16b, v0.16b\n"
+    "sqdmulh v10.4s, v10.4s, v4.4s\n"
+    "and v28.16b, v9.16b, v0.16b\n"
+    "sqdmulh v16.4s, v16.4s, v4.4s\n"
+    "and v29.16b, v22.16b, v0.16b\n"
+    "sqdmulh v21.4s, v21.4s, v4.4s\n"
+    "and v20.16b, v23.16b, v0.16b\n"
+    "sqdmulh v18.4s, v18.4s, v4.4s\n"
+    "sshr v30.4s, v30.4s, #0x1f\n"
+    "and v19.16b, v10.16b, v31.16b\n"
     "sshr v28.4s, v28.4s, #0x1f\n"
-    "sqrdmulh v17.4s, v17.4s, v23.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v31.4s\n"
-    "add v20.4s, v20.4s, v14.4s\n"
-    "sqadd v11.4s, v11.4s, v28.4s\n"
-    "and v26.16b, v17.16b, v9.16b\n"
-    "smin v20.4s, v20.4s, v15.4s\n"
-    "and v8.16b, v18.16b, v30.16b\n"
-    "srshl v11.4s, v11.4s, v30.4s\n"
-    "smax v20.4s, v20.4s, v16.4s\n"
+    "and v4.16b, v16.16b, v31.16b\n"
+    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v5.16b, v21.16b, v31.16b\n"
+    "sshr v20.4s, v20.4s, #0x1f\n"
+    "and v26.16b, v18.16b, v31.16b\n"
+    "sqadd v15.4s, v15.4s, v30.4s\n"
+    "sshr v19.4s, v19.4s, #0x1f\n"
+    "sqadd v9.4s, v9.4s, v28.4s\n"
+    "sshr v4.4s, v4.4s, #0x1f\n"
+    "sqadd v22.4s, v22.4s, v29.4s\n"
+    "sshr v5.4s, v5.4s, #0x1f\n"
+    "sqadd v23.4s, v23.4s, v20.4s\n"
     "sshr v26.4s, v26.4s, #0x1f\n"
-    "sshr v8.4s, v8.4s, #0x1f\n"
-    "uzp1 v19.16b, v19.16b, v20.16b\n"
-    "add v11.4s, v11.4s, v14.4s\n"
-    "uzp1 v19.16b, v19.16b, v19.16b\n"
-    "sqadd v17.4s, v17.4s, v26.4s\n"
-    "smin v11.4s, v11.4s, v15.4s\n"
-    "sqadd v18.4s, v18.4s, v8.4s\n"
-    "sqrdmulh v21.4s, v21.4s, v23.4s\n"
-    "smax v11.4s, v11.4s, v16.4s\n"
-    "srshl v17.4s, v17.4s, v9.4s\n"
-    "srshl v18.4s, v18.4s, v30.4s\n"
-    "and v27.16b, v21.16b, v9.16b\n"
-    "add v17.4s, v17.4s, v14.4s\n"
-    "add v18.4s, v18.4s, v14.4s\n"
-    "sshr v27.4s, v27.4s, #0x1f\n"
-    "smin v17.4s, v17.4s, v15.4s\n"
-    "smin v18.4s, v18.4s, v15.4s\n"
-    "sqadd v21.4s, v21.4s, v27.4s\n"
-    "smax v17.4s, v17.4s, v16.4s\n"
-    "smax v18.4s, v18.4s, v16.4s\n"
-    "srshl v21.4s, v21.4s, v9.4s\n"
-    "uzp1 v11.16b, v11.16b, v17.16b\n"
-    "uzp1 v11.16b, v11.16b, v11.16b\n"
-    "add v21.4s, v21.4s, v14.4s\n"
-    "smin v21.4s, v21.4s, v15.4s\n"
-    "smax v21.4s, v21.4s, v16.4s\n"
-    "uzp1 v18.16b, v18.16b, v21.16b\n"
-    "uzp1 v18.16b, v18.16b, v18.16b\n"
-    "tbz x3, #2, 85f\n"
-    "st1 { v13.s }[0], [x15], #0x4\n"
-    "st1 { v19.s }[0], [x14], #0x4\n"
-    "st1 { v11.s }[0], [x13], #0x4\n"
-    "st1 { v18.s }[0], [x12], #0x4\n"
-    "tbz x3, #1, 84f\n"
-    "st1 { v13.h }[2], [x15], #0x2\n"
-    "st1 { v19.h }[2], [x14], #0x2\n"
-    "st1 { v11.h }[2], [x13], #0x2\n"
-    "st1 { v18.h }[2], [x12], #0x2\n"
-    "tbz x3, #0, 87f\n"
-    "st1 { v13.b }[6], [x15], #0x1\n"
-    "st1 { v19.b }[6], [x14], #0x1\n"
-    "st1 { v11.b }[6], [x13], #0x1\n"
-    "st1 { v18.b }[6], [x12], #0x1\n"
+    "srshl v15.4s, v15.4s, v0.4s\n"
+    "sqadd v10.4s, v10.4s, v19.4s\n"
+    "srshl v9.4s, v9.4s, v0.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v22.4s, v22.4s, v0.4s\n"
+    "sqadd v21.4s, v21.4s, v5.4s\n"
+    "srshl v23.4s, v23.4s, v0.4s\n"
+    "sqadd v18.4s, v18.4s, v26.4s\n"
+    "srshl v10.4s, v10.4s, v31.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v16.4s, v16.4s, v31.4s\n"
+    "sqxtn v9.4h, v9.4s\n"
+    "srshl v21.4s, v21.4s, v31.4s\n"
+    "sqxtn v22.4h, v22.4s\n"
+    "srshl v18.4s, v18.4s, v31.4s\n"
+    "sqxtn v23.4h, v23.4s\n"
+    "sqxtn2 v15.8h, v10.4s\n"
+    "sqxtn2 v9.8h, v16.4s\n"
+    "sqxtn2 v22.8h, v21.4s\n"
+    "sqxtn2 v23.8h, v18.4s\n"
+    "sqadd v15.8h, v15.8h, v11.8h\n"
+    "sqadd v9.8h, v9.8h, v11.8h\n"
+    "sqadd v22.8h, v22.8h, v11.8h\n"
+    "sqadd v23.8h, v23.8h, v11.8h\n"
+    "smax v15.8h, v15.8h, v17.8h\n"
+    "smax v9.8h, v9.8h, v17.8h\n"
+    "smax v22.8h, v22.8h, v17.8h\n"
+    "smax v23.8h, v23.8h, v17.8h\n"
+    "smin v15.8h, v15.8h, v14.8h\n"
+    "smin v9.8h, v9.8h, v14.8h\n"
+    "smin v22.8h, v22.8h, v14.8h\n"
+    "smin v23.8h, v23.8h, v14.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "uzp1 v9.16b, v9.16b, v9.16b\n"
+    "uzp1 v22.16b, v22.16b, v22.16b\n"
+    "uzp1 v23.16b, v23.16b, v23.16b\n"
+    "tbz x8, #2, 85f\n"
+    "st1 { v15.s }[0], [x10], #0x4\n"
+    "st1 { v9.s }[0], [x9], #0x4\n"
+    "st1 { v22.s }[0], [x28], #0x4\n"
+    "st1 { v23.s }[0], [x27], #0x4\n"
+    "tbz x8, #1, 84f\n"
+    "st1 { v15.h }[2], [x10], #0x2\n"
+    "st1 { v9.h }[2], [x9], #0x2\n"
+    "st1 { v22.h }[2], [x28], #0x2\n"
+    "st1 { v23.h }[2], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[6], [x10], #0x1\n"
+    "st1 { v9.b }[6], [x9], #0x1\n"
+    "st1 { v22.b }[6], [x28], #0x1\n"
+    "st1 { v23.b }[6], [x27], #0x1\n"
     "b 87f\n"
     "84:"  // Oddments: Bit 2: Bit 1: Unset
-    "tbz x3, #0, 87f\n"
-    "st1 { v13.b }[4], [x15], #0x1\n"
-    "st1 { v19.b }[4], [x14], #0x1\n"
-    "st1 { v11.b }[4], [x13], #0x1\n"
-    "st1 { v18.b }[4], [x12], #0x1\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[4], [x10], #0x1\n"
+    "st1 { v9.b }[4], [x9], #0x1\n"
+    "st1 { v22.b }[4], [x28], #0x1\n"
+    "st1 { v23.b }[4], [x27], #0x1\n"
     "b 87f\n"
     "85:"  // Oddments: Bit 2: Unset
-    "tbz x3, #1, 86f\n"
-    "st1 { v13.h }[0], [x15], #0x2\n"
-    "st1 { v19.h }[0], [x14], #0x2\n"
-    "st1 { v11.h }[0], [x13], #0x2\n"
-    "st1 { v18.h }[0], [x12], #0x2\n"
-    "tbz x3, #0, 87f\n"
-    "st1 { v13.b }[2], [x15], #0x1\n"
-    "st1 { v19.b }[2], [x14], #0x1\n"
-    "st1 { v11.b }[2], [x13], #0x1\n"
-    "st1 { v18.b }[2], [x12], #0x1\n"
+    "tbz x8, #1, 86f\n"
+    "st1 { v15.h }[0], [x10], #0x2\n"
+    "st1 { v9.h }[0], [x9], #0x2\n"
+    "st1 { v22.h }[0], [x28], #0x2\n"
+    "st1 { v23.h }[0], [x27], #0x2\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[2], [x10], #0x1\n"
+    "st1 { v9.b }[2], [x9], #0x1\n"
+    "st1 { v22.b }[2], [x28], #0x1\n"
+    "st1 { v23.b }[2], [x27], #0x1\n"
     "b 87f\n"
     "86:"  // Oddments: Bit 2: Unset: Bit 1: Unset
-    "tbz x3, #0, 87f\n"
-    "st1 { v13.b }[0], [x15], #0x1\n"
-    "st1 { v19.b }[0], [x14], #0x1\n"
-    "st1 { v11.b }[0], [x13], #0x1\n"
-    "st1 { v18.b }[0], [x12], #0x1\n"
+    "tbz x8, #0, 87f\n"
+    "st1 { v15.b }[0], [x10], #0x1\n"
+    "st1 { v9.b }[0], [x9], #0x1\n"
+    "st1 { v22.b }[0], [x28], #0x1\n"
+    "st1 { v23.b }[0], [x27], #0x1\n"
     "87:"  // Oddments: Bit 2: End
-
     "88:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
-    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+    : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
   );
 }
 
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index d3d5000..ea70b56 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,24 @@
 
 void a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst
+class a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>  // Template arguments mirror the removed typedefs (presumably input, weight, return, bias order - confirm against the base class).
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>;  // Shorthand for the base; private because it precedes `public:`.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 5, 5, 1, 1) {}  // CPUInfo is ignored. Arguments presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols), replacing the removed output_rows/output_cols constants - TODO confirm order.
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::None; }  // Replaces the removed static `vl_type` constant; same value, now reported through a virtual accessor.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_a64_s8q_5x5_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_a64_s8q_5x5_mla::get_packed_size;
-
-  kern_type kernel = a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
-
-  a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;  // Defaults to the implementation function declared just above this class.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the stored kernel pointer.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth in units of vector length - confirm against DepthwiseDepthfirstStrategy.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
index 9715613..291ffec 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -46,7 +46,7 @@
   struct Params
   {
     long unsigned int n_channels;
-    const int8_t *weights;
+    const void *weights;
     const int32_t *bias;
     const arm_gemm::Requantize32 *requant;
     const int32_t *const requant_muls;
@@ -57,7 +57,7 @@
     Params(
       long unsigned int n_channels,
       const uint8_t *const *inptrs_raw,
-      const int8_t *const weights,
+      const void *const weights,
       const int32_t *const bias,
       const arm_gemm::Requantize32 &qp,
       const int32_t *const requant_muls,
@@ -111,2096 +111,2070 @@
                       requant_muls, requant_shifts, outptrs);
 
   __asm__ __volatile__(
-    "ldr x4, [%x[params], %[offsetof_Params_n_channels]]\n"
-    "mov x10, #0x0\n"
-    "ldr x3, [%x[params], %[offsetof_Params_weights]]\n"
-    "mov x1, #0x0\n"
-    "ldr x22, [%x[params], %[offsetof_Params_requant]]\n"
-    "add x25, %x[params], %[offsetof_Params_inptrs]\n"
-    "ldr x2, [%x[params], %[offsetof_Params_requant_muls]]\n"
-    "lsr x19, x4, #0x3\n"
-    "ldr x5, [%x[params], %[offsetof_Params_requant_shifts]]\n"
-    "add x13, x22, %[offsetof_Requantize32_a_offset]\n"
-    "ldr x21, [%x[params], %[offsetof_Params_outptrs]]\n"
-    "add x20, x22, %[offsetof_Requantize32_b_offset]\n"
-    "ld1r { v9.16b }, [x13]\n"
-    "add x8, x22, %[offsetof_Requantize32_c_offset]\n"
-    "ld1r { v14.16b }, [x20]\n"
-    "add x20, x22, %[offsetof_Requantize32_minval]\n"
-    "ld1r { v10.4s }, [x8]\n"
-    "add x8, x22, %[offsetof_Requantize32_maxval]\n"
-    "ld1r { v11.4s }, [x20]\n"
-    "ld1r { v13.4s }, [x8]\n"
-    "ldp x17, x16, [x21, #0x0]\n"
-    "ldp x6, x8, [x21, #0x10]\n"
-    "cbz x19, 3f\n"
-    "subs x19, x19, #0x1\n"
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q15, [x12, #0x0]\n"
-    "mov v16.16b, v15.16b\n"
-    "ldr q18, [x12, #0x10]\n"
-    "add x12, x12, #0x20\n"
-    "mov v7.16b, v15.16b\n"
-    "str x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v8.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v21.16b, v18.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "mov v17.16b, v18.16b\n"
-    "ldr d3, [x3, #0x18]\n"
-    "mov v5.16b, v18.16b\n"
-    "ldr d4, [x3, #0x20]\n"
+    "ldr x10, [%x[params], %[offsetof_Params_requant]]\n"
+    "ldr x0, [%x[params], %[offsetof_Params_n_channels]]\n"
+    "add x17, x10, %[offsetof_Requantize32_a_offset]\n"
+    "add x9, x10, %[offsetof_Requantize32_b_offset]\n"
+    "ldr x25, [%x[params], %[offsetof_Params_outptrs]]\n"
+    "add x4, x10, %[offsetof_Requantize32_c_offset]\n"
+    "add x14, x10, %[offsetof_Requantize32_minval]\n"
+    "ldr x23, [%x[params], %[offsetof_Params_weights]]\n"
+    "add x5, x10, %[offsetof_Requantize32_maxval]\n"
+    "ld1r { v9.16b }, [x17]\n"
+    "ld1r { v14.16b }, [x9]\n"
+    "lsr x3, x0, #0x3\n"
+    "ld1r { v18.8h }, [x4]\n"
+    "ld1r { v11.8h }, [x14]\n"
+    "mov x24, #0x0\n"
+    "mov x22, #0x0\n"
+    "ld1r { v13.8h }, [x5]\n"
+    "ldr x10, [%x[params], %[offsetof_Params_requant_muls]]\n"
+    "add x20, %x[params], %[offsetof_Params_inptrs]\n"
+    "ldr x1, [%x[params], %[offsetof_Params_requant_shifts]]\n"
+    "ldp x16, x8, [x25, #0x0]\n"
+    "ldp x4, x7, [x25, #0x10]\n"
+    "cbz x3, 3f\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "subs x3, x3, #0x1\n"
+    "mov v17.16b, v15.16b\n"
+    "ldr q16, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "ldr d2, [x23, #0x10]\n"
+    "mov v8.16b, v16.16b\n"
+    "mov v10.16b, v15.16b\n"
+    "ldr d3, [x23, #0x18]\n"
+    "ldr d4, [x23, #0x20]\n"
+    "mov v7.16b, v16.16b\n"
+    "mov v6.16b, v15.16b\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "mov v5.16b, v16.16b\n"
     "ssubl v0.8h, v0.8b, v14.8b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "ldp x27, x21, [x20, #0x30]\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
     "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "ldr d31, [x28, x24]\n"
     "ssubl v3.8h, v3.8b, v14.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
     "ssubl v4.8h, v4.8b, v14.8b\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "ldr d31, [x28, x10]\n"
+    "ldr d30, [x6, x24]\n"
+    "ldr d29, [x26, x24]\n"
     "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr d30, [x27, x10]\n"
-    "ldr d29, [x26, x10]\n"
     "usubl v30.8h, v30.8b, v9.8b\n"
-    "ldr d28, [x13, x10]\n"
+    "ldr d28, [x25, x24]\n"
+    "ldr d27, [x5, x24]\n"
     "usubl v29.8h, v29.8b, v9.8b\n"
-    "ldr d27, [x24, x10]\n"
-    "ldr d23, [x23, x10]\n"
     "usubl v28.8h, v28.8b, v9.8b\n"
-    "ldr d25, [x22, x10]\n"
-    "ldr d24, [x21, x10]\n"
+    "ldr d23, [x2, x24]\n"
+    "ldr d25, [x27, x24]\n"
     "usubl v27.8h, v27.8b, v9.8b\n"
     "usubl v23.8h, v23.8b, v9.8b\n"
-    "ldr d26, [x20, x10]\n"
-    "ldr d22, [x0, x10]\n"
+    "ldr d24, [x21, x24]\n"
+    "ldr d26, [x12, x24]\n"
     "usubl v25.8h, v25.8b, v9.8b\n"
     "usubl v24.8h, v24.8b, v9.8b\n"
+    "ldr d22, [x19, x24]\n"
     "usubl v26.8h, v26.8b, v9.8b\n"
     "usubl v22.8h, v22.8b, v9.8b\n"
     "beq 2f\n"
     "1:"  // Loop
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "ldr x20, [x25, #0x50]\n"
-    "subs x19, x19, #0x1\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "smlal v16.4s, v30.4h, v0.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "smlal2 v21.4s, v30.8h, v0.8h\n"
-    "ldr d31, [x20, x10]\n"
-    "smlal v7.4s, v29.4h, v0.4h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal2 v17.4s, v29.8h, v0.8h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "smlal v8.4s, v28.4h, v0.4h\n"
-    "ldr x23, [x25, #0x78]\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x20, #0x50]\n"
+    "ldr d31, [x19, x24]\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "ldr x15, [x20, #0x58]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x20, #0x60]\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "ldr x5, [x20, #0x70]\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
     "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "smlal2 v18.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "smlal v16.4s, v27.4h, v1.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "smlal2 v21.4s, v27.8h, v1.8h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "smlal v7.4s, v28.4h, v1.4h\n"
-    "ldr x21, [x25, #0x98]\n"
-    "smlal2 v17.4s, v28.8h, v1.8h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "smlal v8.4s, v23.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "ldr d1, [x3, #0x30]\n"
+    "ldr d30, [x15, x24]\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x23, #0x28]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "ldr x12, [x20, #0x80]\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
     "smlal v15.4s, v27.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "smlal2 v18.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x0, x10]\n"
-    "smlal v16.4s, v25.4h, v2.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "smlal2 v21.4s, v25.8h, v2.8h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "smlal v7.4s, v23.4h, v2.4h\n"
-    "ldr x9, [x25, #0xc8]\n"
-    "smlal2 v17.4s, v23.8h, v2.8h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr x28, [x25, #0xd8]\n"
+    "ldr x14, [x20, #0x90]\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "smlal2 v5.4s, v23.8h, v1.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x23, #0x30]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr q6, [x2, #0x0]\n"
-    "smlal2 v18.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x7, x10]\n"
-    "smlal v8.4s, v31.4h, v2.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
     "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "ldr d2, [x3, #0x38]\n"
-    "smlal v16.4s, v24.4h, v3.4h\n"
-    "ldr q19, [x5, #0x0]\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "ldr q20, [x2, #0x10]\n"
-    "add x2, x2, #0x20\n"
-    "smlal v7.4s, v31.4h, v3.4h\n"
-    "ldr q12, [x5, #0x10]\n"
-    "add x5, x5, #0x20\n"
-    "smlal2 v17.4s, v31.8h, v3.8h\n"
-    "usubl v30.8h, v30.8b, v9.8b\n"
+    "ldr d25, [x27, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x23, #0x38]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v18.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x26, x10]\n"
-    "smlal v8.4s, v30.4h, v3.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
     "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "ldr d3, [x3, #0x40]\n"
+    "ldr d24, [x5, x24]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x23, #0x40]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x11, x24]\n"
     "usubl v27.8h, v27.8b, v9.8b\n"
-    "smlal v7.4s, v30.4h, v4.4h\n"
-    "smlal2 v17.4s, v30.8h, v4.8h\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x10]\n"
-    "smlal v8.4s, v26.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0x48]\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
     "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v18.4s, v29.8h, v0.8h\n"
-    "smlal v16.4s, v28.4h, v0.4h\n"
-    "smlal2 v21.4s, v28.8h, v0.8h\n"
-    "smlal v7.4s, v22.4h, v0.4h\n"
-    "smlal2 v17.4s, v22.8h, v0.8h\n"
-    "smlal v8.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "ldr d0, [x3, #0x50]\n"
-    "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x10]\n"
-    "smlal v16.4s, v23.4h, v1.4h\n"
-    "ldr x23, [x25, #0xf8]\n"
-    "smlal2 v21.4s, v23.8h, v1.8h\n"
-    "smlal v7.4s, v25.4h, v1.4h\n"
-    "smlal2 v17.4s, v25.8h, v1.8h\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "smlal v8.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "ldr d1, [x3, #0x58]\n"
-    "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v18.4s, v23.8h, v2.8h\n"
-    "ldr d23, [x20, x10]\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "ldr x22, [x25, #0x100]\n"
-    "smlal2 v21.4s, v31.8h, v2.8h\n"
-    "smlal v7.4s, v24.4h, v2.4h\n"
-    "smlal2 v17.4s, v24.8h, v2.8h\n"
-    "smlal v8.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "ldr d2, [x3, #0x60]\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "usubl v23.8h, v23.8b, v9.8b\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
-    "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v18.4s, v31.8h, v3.8h\n"
-    "ldr d31, [x13, x10]\n"
-    "smlal v16.4s, v30.4h, v3.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "smlal2 v21.4s, v30.8h, v3.8h\n"
-    "smlal v7.4s, v27.4h, v3.4h\n"
-    "smlal2 v17.4s, v27.8h, v3.8h\n"
-    "smlal v8.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "ldr d3, [x3, #0x68]\n"
-    "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v18.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x10]\n"
-    "smlal v16.4s, v26.4h, v4.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "smlal2 v21.4s, v26.8h, v4.8h\n"
-    "ldr d26, [x14, x10]\n"
-    "smlal v7.4s, v23.4h, v4.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "smlal2 v17.4s, v23.8h, v4.8h\n"
-    "usubl v28.8h, v28.8b, v9.8b\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "usubl v31.8h, v31.8b, v9.8b\n"
-    "smlal v8.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "ldr d4, [x3, #0x70]\n"
-    "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v18.4s, v22.8h, v0.8h\n"
-    "ldr d22, [x0, x10]\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v21.4s, v25.8h, v0.8h\n"
-    "smlal v7.4s, v31.4h, v0.4h\n"
-    "smlal2 v17.4s, v31.8h, v0.8h\n"
-    "usubl v30.8h, v30.8b, v9.8b\n"
-    "ssubl v1.8h, v1.8b, v14.8b\n"
-    "usubl v26.8h, v26.8b, v9.8b\n"
-    "smlal v8.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "ldr d0, [x3, #0x78]\n"
-    "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v18.4s, v25.8h, v1.8h\n"
-    "ldr d25, [x11, x10]\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v21.4s, v24.8h, v1.8h\n"
-    "smlal v7.4s, v30.4h, v1.4h\n"
-    "smlal2 v17.4s, v30.8h, v1.8h\n"
-    "smlal v8.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "ldr d1, [x3, #0x80]\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v18.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x24, x10]\n"
-    "smlal v16.4s, v27.4h, v2.4h\n"
-    "smlal2 v21.4s, v27.8h, v2.8h\n"
-    "smlal v7.4s, v26.4h, v2.4h\n"
-    "smlal2 v17.4s, v26.8h, v2.8h\n"
-    "smlal v8.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ldr d2, [x3, #0x88]\n"
-    "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v18.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x15, x10]\n"
-    "smlal v16.4s, v23.4h, v3.4h\n"
-    "smlal2 v21.4s, v23.8h, v3.8h\n"
-    "smlal v7.4s, v25.4h, v3.4h\n"
-    "smlal2 v17.4s, v25.8h, v3.8h\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
-    "usubl v22.8h, v22.8b, v9.8b\n"
-    "smlal v8.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "ldr d3, [x3, #0x90]\n"
-    "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v18.4s, v23.8h, v4.8h\n"
-    "ldr d23, [x9, x10]\n"
-    "smlal v16.4s, v28.4h, v4.4h\n"
-    "smlal2 v21.4s, v28.8h, v4.8h\n"
-    "ldr d28, [x12, x10]\n"
-    "smlal v7.4s, v24.4h, v4.4h\n"
-    "smlal2 v17.4s, v24.8h, v4.8h\n"
-    "smlal v8.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "ldr d4, [x3, #0x98]\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "usubl v23.8h, v23.8b, v9.8b\n"
-    "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x27, x10]\n"
-    "smlal v16.4s, v30.4h, v0.4h\n"
-    "smlal2 v21.4s, v30.8h, v0.8h\n"
-    "smlal v7.4s, v27.4h, v0.4h\n"
-    "smlal2 v17.4s, v27.8h, v0.8h\n"
-    "smlal v8.4s, v23.4h, v0.4h\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "ldr d0, [x3, #0xa0]\n"
-    "ssubl v1.8h, v1.8b, v14.8b\n"
-    "usubl v31.8h, v31.8b, v9.8b\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v18.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "smlal v16.4s, v26.4h, v1.4h\n"
-    "smlal2 v21.4s, v26.8h, v1.8h\n"
-    "smlal v7.4s, v23.4h, v1.4h\n"
-    "smlal2 v17.4s, v23.8h, v1.8h\n"
-    "smlal v8.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "ldr d1, [x3, #0xa8]\n"
-    "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v18.4s, v26.8h, v2.8h\n"
-    "ldr d26, [x7, x10]\n"
-    "smlal v16.4s, v25.4h, v2.4h\n"
-    "smlal2 v21.4s, v25.8h, v2.8h\n"
-    "smlal v7.4s, v31.4h, v2.4h\n"
-    "smlal2 v17.4s, v31.8h, v2.8h\n"
-    "usubl v30.8h, v30.8b, v9.8b\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "usubl v28.8h, v28.8b, v9.8b\n"
-    "smlal v8.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "ldr d2, [x3, #0xb0]\n"
-    "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v18.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x26, x10]\n"
-    "smlal v16.4s, v24.4h, v3.4h\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "smlal v7.4s, v30.4h, v3.4h\n"
-    "smlal2 v17.4s, v30.8h, v3.8h\n"
-    "smlal v8.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "ldr d3, [x3, #0xb8]\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
-    "usubl v26.8h, v26.8b, v9.8b\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v18.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x23, x10]\n"
-    "smlal v16.4s, v22.4h, v4.4h\n"
-    "smlal2 v21.4s, v22.8h, v4.8h\n"
-    "smlal v7.4s, v28.4h, v4.4h\n"
-    "smlal2 v17.4s, v28.8h, v4.8h\n"
-    "smlal v8.4s, v26.4h, v4.4h\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
     "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0xc0]\n"
-    "add x3, x3, #0xc8\n"
-    "smlal v15.4s, v27.4h, v0.4h\n"
-    "smlal2 v18.4s, v27.8h, v0.8h\n"
-    "ldr d27, [x22, x10]\n"
-    "smlal v16.4s, v23.4h, v0.4h\n"
-    "smlal2 v21.4s, v23.8h, v0.8h\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
+    "ldr d4, [x23, #0x48]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "ldr q12, [x10, #0x0]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "ldr q19, [x1, #0x0]\n"
+    "ldr q20, [x10, #0x10]\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "smlal v15.4s, v28.4h, v1.4h\n"
+    "ldr q29, [x1, #0x10]\n"
+    "subs x3, x3, #0x1\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x26, x24]\n"
+    "ldr d0, [x23, #0x50]\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "add x10, x10, #0x20\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "smlal v15.4s, v23.4h, v2.4h\n"
+    "add x1, x1, #0x20\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x12, x24]\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x23, #0x58]\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
-    "smlal v7.4s, v25.4h, v0.4h\n"
-    "smlal2 v17.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x20, x10]\n"
-    "smlal v8.4s, v24.4h, v0.4h\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "smlal v15.4s, v31.4h, v3.4h\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x14, x24]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x23, #0x60]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "smlal v15.4s, v30.4h, v4.4h\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x15, x24]\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x23, #0x68]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x21, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "smlal v15.4s, v22.4h, v0.4h\n"
+    "ldr x14, [x20, #0x110]\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x23, #0x70]\n"
+    "ldr d22, [x9, x24]\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v25.4h, v1.4h\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x2, x24]\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x23, #0x78]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x13, x24]\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x23, #0x80]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "smlal v15.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x23, #0x88]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal v15.4s, v23.4h, v4.4h\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x28, x24]\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x23, #0x90]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x11, x24]\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal v15.4s, v31.4h, v0.4h\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x6, x24]\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x23, #0x98]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x24]\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v2.4h\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x17, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v3.4h\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x5, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "smlal v15.4s, v24.4h, v4.4h\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x25, x24]\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "smlal v15.4s, v27.4h, v0.4h\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x26, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "add x23, x23, #0xc8\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x12, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
     "smlal2 v5.4s, v24.8h, v0.8h\n"
     "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v18.4s, v23.8h, v1.8h\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "smlal v7.4s, v24.4h, v1.4h\n"
-    "smlal2 v17.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x13, x10]\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "smlal v8.4s, v27.4h, v1.4h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x14, x24]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
     "smlal2 v5.4s, v27.8h, v1.8h\n"
     "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v18.4s, v31.8h, v2.8h\n"
-    "smlal v16.4s, v30.4h, v2.4h\n"
-    "smlal2 v21.4s, v30.8h, v2.8h\n"
-    "smlal v7.4s, v27.4h, v2.4h\n"
-    "smlal2 v17.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x10]\n"
-    "add x10, x10, #0x8\n"
-    "smlal v8.4s, v25.4h, v2.4h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x24]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
     "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "add x24, x24, #0x8\n"
     "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v18.4s, v30.8h, v3.8h\n"
-    "smlal v16.4s, v28.4h, v3.4h\n"
-    "smlal2 v21.4s, v28.8h, v3.8h\n"
-    "smlal v7.4s, v25.4h, v3.4h\n"
-    "smlal2 v17.4s, v25.8h, v3.8h\n"
-    "smlal v8.4s, v24.4h, v3.4h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
     "smlal2 v5.4s, v24.8h, v3.8h\n"
     "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v18.4s, v28.8h, v4.8h\n"
-    "smlal v16.4s, v26.4h, v4.4h\n"
-    "smlal2 v21.4s, v26.8h, v4.8h\n"
-    "smlal v7.4s, v24.4h, v4.4h\n"
-    "smlal2 v17.4s, v24.8h, v4.8h\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v20.4s\n"
-    "smlal v8.4s, v27.4h, v4.4h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
     "smlal2 v5.4s, v27.8h, v4.8h\n"
-    "and v28.16b, v15.16b, v19.16b\n"
-    "and v26.16b, v18.16b, v12.16b\n"
-    "sqrdmulh v16.4s, v16.4s, v6.4s\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
     "sshr v26.4s, v26.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v20.4s\n"
-    "sqadd v15.4s, v15.4s, v28.4s\n"
-    "sqadd v18.4s, v18.4s, v26.4s\n"
-    "and v29.16b, v16.16b, v19.16b\n"
-    "and v4.16b, v21.16b, v12.16b\n"
-    "srshl v15.4s, v15.4s, v19.4s\n"
-    "srshl v18.4s, v18.4s, v12.4s\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v10.4s\n"
-    "add v18.4s, v18.4s, v10.4s\n"
-    "sqadd v16.4s, v16.4s, v29.4s\n"
-    "smin v15.4s, v15.4s, v13.4s\n"
-    "smin v18.4s, v18.4s, v13.4s\n"
-    "sqadd v21.4s, v21.4s, v4.4s\n"
-    "smax v15.4s, v15.4s, v11.4s\n"
-    "smax v18.4s, v18.4s, v11.4s\n"
-    "srshl v16.4s, v16.4s, v19.4s\n"
-    "srshl v21.4s, v21.4s, v12.4s\n"
-    "uzp1 v15.16b, v15.16b, v18.16b\n"
-    "sqrdmulh v7.4s, v7.4s, v6.4s\n"
-    "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x17, x1]\n"
-    "add v16.4s, v16.4s, v10.4s\n"
-    "add v21.4s, v21.4s, v10.4s\n"
-    "and v25.16b, v7.16b, v19.16b\n"
-    "sqrdmulh v17.4s, v17.4s, v20.4s\n"
-    "smin v16.4s, v16.4s, v13.4s\n"
-    "smin v21.4s, v21.4s, v13.4s\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
+    "sshr v2.4s, v2.4s, #0x1f\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
     "sshr v25.4s, v25.4s, #0x1f\n"
-    "smax v16.4s, v16.4s, v11.4s\n"
-    "smax v21.4s, v21.4s, v11.4s\n"
-    "sqadd v7.4s, v7.4s, v25.4s\n"
-    "and v31.16b, v17.16b, v12.16b\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "sqrdmulh v8.4s, v8.4s, v6.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x16, x1]\n"
-    "srshl v7.4s, v7.4s, v19.4s\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "and v24.16b, v8.16b, v19.16b\n"
-    "sqrdmulh v5.4s, v5.4s, v20.4s\n"
-    "sqadd v17.4s, v17.4s, v31.4s\n"
-    "add v7.4s, v7.4s, v10.4s\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "and v1.16b, v5.16b, v12.16b\n"
-    "smin v7.4s, v7.4s, v13.4s\n"
-    "srshl v17.4s, v17.4s, v12.4s\n"
-    "sqadd v8.4s, v8.4s, v24.4s\n"
-    "smax v7.4s, v7.4s, v11.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "add v17.4s, v17.4s, v10.4s\n"
-    "srshl v8.4s, v8.4s, v19.4s\n"
-    "sqadd v5.4s, v5.4s, v1.4s\n"
-    "smin v17.4s, v17.4s, v13.4s\n"
-    "add v8.4s, v8.4s, v10.4s\n"
-    "smax v17.4s, v17.4s, v11.4s\n"
-    "srshl v5.4s, v5.4s, v12.4s\n"
-    "smin v8.4s, v8.4s, v13.4s\n"
-    "uzp1 v7.16b, v7.16b, v17.16b\n"
-    "add v5.4s, v5.4s, v10.4s\n"
-    "uzp1 v7.16b, v7.16b, v7.16b\n"
-    "str d7, [x6, x1]\n"
-    "smax v8.4s, v8.4s, v11.4s\n"
-    "smin v5.4s, v5.4s, v13.4s\n"
-    "smax v5.4s, v5.4s, v11.4s\n"
-    "uzp1 v8.16b, v8.16b, v5.16b\n"
-    "uzp1 v8.16b, v8.16b, v8.16b\n"
-    "str d8, [x8, x1]\n"
-    "add x1, x1, #0x8\n"
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "ldr q15, [x12, #0x0]\n"
-    "mov v16.16b, v15.16b\n"
-    "ldr q18, [x12, #0x10]\n"
-    "add x12, x12, #0x20\n"
-    "mov v7.16b, v15.16b\n"
-    "str x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "mov v8.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v21.16b, v18.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "mov v17.16b, v18.16b\n"
-    "ldr d3, [x3, #0x18]\n"
-    "mov v5.16b, v18.16b\n"
-    "ldr d4, [x3, #0x20]\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
+    "str d15, [x16, x22]\n"
+    "uzp1 v10.16b, v10.16b, v10.16b\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "str d17, [x8, x22]\n"
+    "str d10, [x4, x22]\n"
+    "str d6, [x7, x22]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr q15, [x19, #0x0]\n"
+    "add x22, x22, #0x8\n"
+    "ldr q16, [x19, #0x10]\n"
+    "add x19, x19, #0x20\n"
+    "str x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "ldr d2, [x23, #0x10]\n"
+    "mov v17.16b, v15.16b\n"
+    "mov v8.16b, v16.16b\n"
+    "ldr d3, [x23, #0x18]\n"
+    "ldr d4, [x23, #0x20]\n"
+    "mov v10.16b, v15.16b\n"
+    "mov v7.16b, v16.16b\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "mov v6.16b, v15.16b\n"
+    "mov v5.16b, v16.16b\n"
+    "ldp x5, x2, [x20, #0x20]\n"
+    "ldp x27, x21, [x20, #0x30]\n"
     "ssubl v0.8h, v0.8b, v14.8b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
+    "ldp x12, x19, [x20, #0x40]\n"
+    "ldr d31, [x28, x24]\n"
     "ssubl v2.8h, v2.8b, v14.8b\n"
     "ssubl v3.8h, v3.8b, v14.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
+    "ldr d30, [x6, x24]\n"
+    "ldr d29, [x26, x24]\n"
     "ssubl v4.8h, v4.8b, v14.8b\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "ldr d31, [x28, x10]\n"
     "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr d30, [x27, x10]\n"
-    "ldr d29, [x26, x10]\n"
+    "ldr d28, [x25, x24]\n"
+    "ldr d27, [x5, x24]\n"
     "usubl v30.8h, v30.8b, v9.8b\n"
-    "ldr d28, [x13, x10]\n"
     "usubl v29.8h, v29.8b, v9.8b\n"
-    "ldr d27, [x24, x10]\n"
-    "ldr d23, [x23, x10]\n"
+    "ldr d23, [x2, x24]\n"
+    "ldr d25, [x27, x24]\n"
     "usubl v28.8h, v28.8b, v9.8b\n"
-    "ldr d25, [x22, x10]\n"
-    "ldr d24, [x21, x10]\n"
     "usubl v27.8h, v27.8b, v9.8b\n"
+    "ldr d24, [x21, x24]\n"
+    "ldr d26, [x12, x24]\n"
     "usubl v23.8h, v23.8b, v9.8b\n"
-    "ldr d26, [x20, x10]\n"
-    "ldr d22, [x0, x10]\n"
     "usubl v25.8h, v25.8b, v9.8b\n"
+    "ldr d22, [x19, x24]\n"
     "usubl v24.8h, v24.8b, v9.8b\n"
     "usubl v26.8h, v26.8b, v9.8b\n"
     "usubl v22.8h, v22.8b, v9.8b\n"
     "bgt 1b\n"
     "2:"  // Tail
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "ldr x20, [x25, #0x50]\n"
-    "tst x4, #0x7\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr x28, [x25, #0x58]\n"
-    "smlal v16.4s, v30.4h, v0.4h\n"
-    "ldr x0, [x25, #0x60]\n"
-    "smlal2 v21.4s, v30.8h, v0.8h\n"
-    "ldr d31, [x20, x10]\n"
-    "smlal v7.4s, v29.4h, v0.4h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "smlal2 v17.4s, v29.8h, v0.8h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "smlal v8.4s, v28.4h, v0.4h\n"
-    "ldr x23, [x25, #0x78]\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr x19, [x20, #0x50]\n"
+    "ldr d31, [x19, x24]\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "ldr x15, [x20, #0x58]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "ldr x19, [x20, #0x60]\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "ldr x5, [x20, #0x70]\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
     "smlal2 v5.4s, v28.8h, v0.8h\n"
-    "ldr d0, [x3, #0x28]\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "smlal2 v18.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "smlal v16.4s, v27.4h, v1.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "smlal2 v21.4s, v27.8h, v1.8h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "smlal v7.4s, v28.4h, v1.4h\n"
-    "ldr x21, [x25, #0x98]\n"
-    "smlal2 v17.4s, v28.8h, v1.8h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "smlal v8.4s, v23.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "smlal2 v5.4s, v23.8h, v1.8h\n"
-    "ldr d1, [x3, #0x30]\n"
+    "ldr d30, [x15, x24]\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "ldr d0, [x23, #0x28]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "ldr x12, [x20, #0x80]\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
     "smlal v15.4s, v27.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "smlal2 v18.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x0, x10]\n"
-    "smlal v16.4s, v25.4h, v2.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "smlal2 v21.4s, v25.8h, v2.8h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "smlal v7.4s, v23.4h, v2.4h\n"
-    "ldr x9, [x25, #0xc8]\n"
-    "smlal2 v17.4s, v23.8h, v2.8h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr x28, [x25, #0xd8]\n"
+    "ldr x14, [x20, #0x90]\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "smlal2 v5.4s, v23.8h, v1.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "ldr d1, [x23, #0x30]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "smlal2 v18.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x7, x10]\n"
-    "smlal v8.4s, v31.4h, v2.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
     "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "ldr d2, [x3, #0x38]\n"
-    "smlal v16.4s, v24.4h, v3.4h\n"
-    "ldr q6, [x2, #0x0]\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "ldr q19, [x5, #0x0]\n"
-    "smlal v7.4s, v31.4h, v3.4h\n"
-    "ldr q20, [x2, #0x10]\n"
-    "add x2, x2, #0x20\n"
-    "smlal2 v17.4s, v31.8h, v3.8h\n"
-    "ldr q12, [x5, #0x10]\n"
-    "add x5, x5, #0x20\n"
-    "usubl v30.8h, v30.8b, v9.8b\n"
+    "ldr d25, [x27, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "ldr d2, [x23, #0x38]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v18.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x26, x10]\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "smlal v8.4s, v30.4h, v3.4h\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
     "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "ldr d3, [x3, #0x40]\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "ldr d27, [x23, x10]\n"
-    "smlal v7.4s, v30.4h, v4.4h\n"
-    "ldr x23, [x25, #0xf8]\n"
-    "smlal2 v17.4s, v30.8h, v4.8h\n"
-    "smlal v8.4s, v26.4h, v4.4h\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0x48]\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr d24, [x5, x24]\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "ldr d3, [x23, #0x40]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "ldr d27, [x11, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
     "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v18.4s, v29.8h, v0.8h\n"
-    "smlal v16.4s, v28.4h, v0.4h\n"
-    "smlal2 v21.4s, v28.8h, v0.8h\n"
-    "smlal v7.4s, v22.4h, v0.4h\n"
-    "smlal2 v17.4s, v22.8h, v0.8h\n"
-    "smlal v8.4s, v25.4h, v0.4h\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
-    "ldr d0, [x3, #0x50]\n"
-    "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "ldr d28, [x22, x10]\n"
-    "smlal v16.4s, v23.4h, v1.4h\n"
-    "ldr x22, [x25, #0x100]\n"
-    "smlal2 v21.4s, v23.8h, v1.8h\n"
-    "smlal v7.4s, v25.4h, v1.4h\n"
-    "smlal2 v17.4s, v25.8h, v1.8h\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "smlal v8.4s, v24.4h, v1.4h\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
-    "ldr d1, [x3, #0x58]\n"
-    "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v18.4s, v23.8h, v2.8h\n"
-    "ldr d23, [x20, x10]\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "smlal2 v21.4s, v31.8h, v2.8h\n"
-    "smlal v7.4s, v24.4h, v2.4h\n"
-    "smlal2 v17.4s, v24.8h, v2.8h\n"
-    "smlal v8.4s, v27.4h, v2.4h\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
-    "ldr d2, [x3, #0x60]\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "usubl v23.8h, v23.8b, v9.8b\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
-    "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v18.4s, v31.8h, v3.8h\n"
-    "ldr d31, [x13, x10]\n"
-    "smlal v16.4s, v30.4h, v3.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "smlal2 v21.4s, v30.8h, v3.8h\n"
-    "smlal v7.4s, v27.4h, v3.4h\n"
-    "smlal2 v17.4s, v27.8h, v3.8h\n"
-    "smlal v8.4s, v23.4h, v3.4h\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
-    "ldr d3, [x3, #0x68]\n"
-    "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v18.4s, v30.8h, v4.8h\n"
-    "ldr d30, [x21, x10]\n"
-    "smlal v16.4s, v26.4h, v4.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "smlal2 v21.4s, v26.8h, v4.8h\n"
-    "ldr d26, [x14, x10]\n"
-    "smlal v7.4s, v23.4h, v4.4h\n"
-    "smlal2 v17.4s, v23.8h, v4.8h\n"
-    "usubl v28.8h, v28.8b, v9.8b\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "usubl v31.8h, v31.8b, v9.8b\n"
-    "smlal v8.4s, v28.4h, v4.4h\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
-    "ldr d4, [x3, #0x70]\n"
-    "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v18.4s, v22.8h, v0.8h\n"
-    "ldr d22, [x0, x10]\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v21.4s, v25.8h, v0.8h\n"
-    "smlal v7.4s, v31.4h, v0.4h\n"
-    "smlal2 v17.4s, v31.8h, v0.8h\n"
-    "usubl v30.8h, v30.8b, v9.8b\n"
-    "ssubl v1.8h, v1.8b, v14.8b\n"
-    "usubl v26.8h, v26.8b, v9.8b\n"
-    "smlal v8.4s, v30.4h, v0.4h\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
-    "ldr d0, [x3, #0x78]\n"
-    "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v18.4s, v25.8h, v1.8h\n"
-    "ldr d25, [x11, x10]\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v21.4s, v24.8h, v1.8h\n"
-    "smlal v7.4s, v30.4h, v1.4h\n"
-    "smlal2 v17.4s, v30.8h, v1.8h\n"
-    "smlal v8.4s, v26.4h, v1.4h\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
-    "ldr d1, [x3, #0x80]\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v18.4s, v24.8h, v2.8h\n"
-    "ldr d24, [x24, x10]\n"
-    "smlal v16.4s, v27.4h, v2.4h\n"
-    "smlal2 v21.4s, v27.8h, v2.8h\n"
-    "smlal v7.4s, v26.4h, v2.4h\n"
-    "smlal2 v17.4s, v26.8h, v2.8h\n"
-    "smlal v8.4s, v25.4h, v2.4h\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ldr d2, [x3, #0x88]\n"
-    "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v18.4s, v27.8h, v3.8h\n"
-    "ldr d27, [x15, x10]\n"
-    "smlal v16.4s, v23.4h, v3.4h\n"
-    "smlal2 v21.4s, v23.8h, v3.8h\n"
-    "smlal v7.4s, v25.4h, v3.4h\n"
-    "smlal2 v17.4s, v25.8h, v3.8h\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
-    "usubl v22.8h, v22.8b, v9.8b\n"
-    "smlal v8.4s, v24.4h, v3.4h\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
-    "ldr d3, [x3, #0x90]\n"
-    "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v18.4s, v23.8h, v4.8h\n"
-    "ldr d23, [x9, x10]\n"
-    "smlal v16.4s, v28.4h, v4.4h\n"
-    "smlal2 v21.4s, v28.8h, v4.8h\n"
-    "ldr d28, [x12, x10]\n"
-    "smlal v7.4s, v24.4h, v4.4h\n"
-    "smlal2 v17.4s, v24.8h, v4.8h\n"
-    "smlal v8.4s, v22.4h, v4.4h\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
-    "ldr d4, [x3, #0x98]\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "usubl v23.8h, v23.8b, v9.8b\n"
-    "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "ldr d31, [x27, x10]\n"
-    "smlal v16.4s, v30.4h, v0.4h\n"
-    "smlal2 v21.4s, v30.8h, v0.8h\n"
-    "smlal v7.4s, v27.4h, v0.4h\n"
-    "smlal2 v17.4s, v27.8h, v0.8h\n"
-    "smlal v8.4s, v23.4h, v0.4h\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
-    "ldr d0, [x3, #0xa0]\n"
-    "ssubl v1.8h, v1.8b, v14.8b\n"
-    "usubl v31.8h, v31.8b, v9.8b\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v18.4s, v30.8h, v1.8h\n"
-    "ldr d30, [x28, x10]\n"
-    "smlal v16.4s, v26.4h, v1.4h\n"
-    "smlal2 v21.4s, v26.8h, v1.8h\n"
-    "smlal v7.4s, v23.4h, v1.4h\n"
-    "smlal2 v17.4s, v23.8h, v1.8h\n"
-    "smlal v8.4s, v31.4h, v1.4h\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
-    "ldr d1, [x3, #0xa8]\n"
-    "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v18.4s, v26.8h, v2.8h\n"
-    "ldr d26, [x7, x10]\n"
-    "smlal v16.4s, v25.4h, v2.4h\n"
-    "smlal2 v21.4s, v25.8h, v2.8h\n"
-    "smlal v7.4s, v31.4h, v2.4h\n"
-    "smlal2 v17.4s, v31.8h, v2.8h\n"
-    "usubl v30.8h, v30.8b, v9.8b\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "usubl v28.8h, v28.8b, v9.8b\n"
-    "smlal v8.4s, v30.4h, v2.4h\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
-    "ldr d2, [x3, #0xb0]\n"
-    "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v18.4s, v25.8h, v3.8h\n"
-    "ldr d25, [x26, x10]\n"
-    "smlal v16.4s, v24.4h, v3.4h\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "smlal v7.4s, v30.4h, v3.4h\n"
-    "smlal2 v17.4s, v30.8h, v3.8h\n"
-    "smlal v8.4s, v28.4h, v3.4h\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
-    "ldr d3, [x3, #0xb8]\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
-    "usubl v26.8h, v26.8b, v9.8b\n"
-    "ssubl v0.8h, v0.8b, v14.8b\n"
-    "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v18.4s, v24.8h, v4.8h\n"
-    "ldr d24, [x23, x10]\n"
-    "smlal v16.4s, v22.4h, v4.4h\n"
-    "smlal2 v21.4s, v22.8h, v4.8h\n"
-    "smlal v7.4s, v28.4h, v4.4h\n"
-    "smlal2 v17.4s, v28.8h, v4.8h\n"
-    "smlal v8.4s, v26.4h, v4.4h\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
     "smlal2 v5.4s, v26.8h, v4.8h\n"
-    "ldr d4, [x3, #0xc0]\n"
-    "smlal v15.4s, v27.4h, v0.4h\n"
-    "smlal2 v18.4s, v27.8h, v0.8h\n"
-    "ldr d27, [x22, x10]\n"
-    "smlal v16.4s, v23.4h, v0.4h\n"
-    "smlal2 v21.4s, v23.8h, v0.8h\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
+    "ldr d4, [x23, #0x48]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "ldr q12, [x10, #0x0]\n"
+    "ldr q19, [x1, #0x0]\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "smlal v15.4s, v28.4h, v1.4h\n"
+    "ldr q20, [x10, #0x10]\n"
+    "ldr q29, [x1, #0x10]\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "ldr d28, [x26, x24]\n"
+    "ldr d0, [x23, #0x50]\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "tst x0, #0x7\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "smlal v15.4s, v23.4h, v2.4h\n"
+    "add x10, x10, #0x20\n"
+    "add x1, x1, #0x20\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "ldr d23, [x12, x24]\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "ldr d1, [x23, #0x58]\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
-    "smlal v7.4s, v25.4h, v0.4h\n"
-    "smlal2 v17.4s, v25.8h, v0.8h\n"
-    "ldr d25, [x20, x10]\n"
-    "smlal v8.4s, v24.4h, v0.4h\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "smlal v15.4s, v31.4h, v3.4h\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "ldr d31, [x14, x24]\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "ldr d2, [x23, #0x60]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "ldr x14, [x20, #0x110]\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "smlal v15.4s, v30.4h, v4.4h\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "ldr d30, [x15, x24]\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "ldr d3, [x23, #0x68]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "ldr d26, [x21, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "smlal v15.4s, v22.4h, v0.4h\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "ldr d4, [x23, #0x70]\n"
+    "ldr d22, [x9, x24]\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "smlal v15.4s, v25.4h, v1.4h\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "ldr d25, [x2, x24]\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "ldr d0, [x23, #0x78]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "smlal v15.4s, v24.4h, v2.4h\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "ldr d24, [x13, x24]\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "ldr d1, [x23, #0x80]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "smlal v15.4s, v27.4h, v3.4h\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "ldr d27, [x19, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "ldr d2, [x23, #0x88]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "smlal v15.4s, v23.4h, v4.4h\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "ldr d23, [x28, x24]\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "ldr d3, [x23, #0x90]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "ldr d28, [x11, x24]\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "smlal v15.4s, v31.4h, v0.4h\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "ldr d31, [x6, x24]\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "usubl v31.8h, v31.8b, v9.8b\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "ldr d4, [x23, #0x98]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "smlal v15.4s, v30.4h, v1.4h\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "ldr d30, [x27, x24]\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "usubl v30.8h, v30.8b, v9.8b\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "ldr d0, [x23, #0xa0]\n"
+    "ssubl v0.8h, v0.8b, v14.8b\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "smlal v15.4s, v26.4h, v2.4h\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "ldr d26, [x17, x24]\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "ldr d1, [x23, #0xa8]\n"
+    "ssubl v1.8h, v1.8b, v14.8b\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "smlal v15.4s, v25.4h, v3.4h\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "ldr d25, [x5, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "ldr d2, [x23, #0xb0]\n"
+    "ssubl v2.8h, v2.8b, v14.8b\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "smlal v15.4s, v24.4h, v4.4h\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "ldr d24, [x25, x24]\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "ldr d3, [x23, #0xb8]\n"
+    "ssubl v3.8h, v3.8b, v14.8b\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "smlal v15.4s, v27.4h, v0.4h\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "ldr d27, [x26, x24]\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "ldr d4, [x23, #0xc0]\n"
+    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "ldr d25, [x12, x24]\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
     "smlal2 v5.4s, v24.8h, v0.8h\n"
     "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v18.4s, v23.8h, v1.8h\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "smlal v7.4s, v24.4h, v1.4h\n"
-    "smlal2 v17.4s, v24.8h, v1.8h\n"
-    "ldr d24, [x13, x10]\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "ssubl v2.8h, v2.8b, v14.8b\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "smlal v8.4s, v27.4h, v1.4h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "ldr d24, [x14, x24]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
     "smlal2 v5.4s, v27.8h, v1.8h\n"
     "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v18.4s, v31.8h, v2.8h\n"
-    "smlal v16.4s, v30.4h, v2.4h\n"
-    "smlal2 v21.4s, v30.8h, v2.8h\n"
-    "smlal v7.4s, v27.4h, v2.4h\n"
-    "smlal2 v17.4s, v27.8h, v2.8h\n"
-    "ldr d27, [x21, x10]\n"
-    "add x10, x10, #0x8\n"
-    "smlal v8.4s, v25.4h, v2.4h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "ldr d27, [x21, x24]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
     "smlal2 v5.4s, v25.8h, v2.8h\n"
-    "ssubl v3.8h, v3.8b, v14.8b\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
-    "ssubl v4.8h, v4.8b, v14.8b\n"
+    "add x24, x24, #0x8\n"
     "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v18.4s, v30.8h, v3.8h\n"
-    "smlal v16.4s, v28.4h, v3.4h\n"
-    "smlal2 v21.4s, v28.8h, v3.8h\n"
-    "smlal v7.4s, v25.4h, v3.4h\n"
-    "smlal2 v17.4s, v25.8h, v3.8h\n"
-    "smlal v8.4s, v24.4h, v3.4h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
     "smlal2 v5.4s, v24.8h, v3.8h\n"
     "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v18.4s, v28.8h, v4.8h\n"
-    "smlal v16.4s, v26.4h, v4.4h\n"
-    "smlal2 v21.4s, v26.8h, v4.8h\n"
-    "smlal v7.4s, v24.4h, v4.4h\n"
-    "smlal2 v17.4s, v24.8h, v4.8h\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "sqrdmulh v18.4s, v18.4s, v20.4s\n"
-    "smlal v8.4s, v27.4h, v4.4h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
     "smlal2 v5.4s, v27.8h, v4.8h\n"
-    "and v28.16b, v15.16b, v19.16b\n"
-    "and v26.16b, v18.16b, v12.16b\n"
-    "sqrdmulh v16.4s, v16.4s, v6.4s\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
     "sshr v26.4s, v26.4s, #0x1f\n"
-    "sqrdmulh v21.4s, v21.4s, v20.4s\n"
-    "sqadd v15.4s, v15.4s, v28.4s\n"
-    "sqadd v18.4s, v18.4s, v26.4s\n"
-    "and v29.16b, v16.16b, v19.16b\n"
-    "and v4.16b, v21.16b, v12.16b\n"
-    "srshl v15.4s, v15.4s, v19.4s\n"
-    "srshl v18.4s, v18.4s, v12.4s\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v10.4s\n"
-    "add v18.4s, v18.4s, v10.4s\n"
-    "sqadd v16.4s, v16.4s, v29.4s\n"
-    "smin v15.4s, v15.4s, v13.4s\n"
-    "smin v18.4s, v18.4s, v13.4s\n"
-    "sqadd v21.4s, v21.4s, v4.4s\n"
-    "smax v15.4s, v15.4s, v11.4s\n"
-    "smax v18.4s, v18.4s, v11.4s\n"
-    "srshl v16.4s, v16.4s, v19.4s\n"
-    "srshl v21.4s, v21.4s, v12.4s\n"
-    "uzp1 v15.16b, v15.16b, v18.16b\n"
-    "sqrdmulh v7.4s, v7.4s, v6.4s\n"
-    "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "str d15, [x17, x1]\n"
-    "add v16.4s, v16.4s, v10.4s\n"
-    "add v21.4s, v21.4s, v10.4s\n"
-    "and v25.16b, v7.16b, v19.16b\n"
-    "sqrdmulh v17.4s, v17.4s, v20.4s\n"
-    "smin v16.4s, v16.4s, v13.4s\n"
-    "smin v21.4s, v21.4s, v13.4s\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
+    "sshr v2.4s, v2.4s, #0x1f\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
     "sshr v25.4s, v25.4s, #0x1f\n"
-    "smax v16.4s, v16.4s, v11.4s\n"
-    "smax v21.4s, v21.4s, v11.4s\n"
-    "sqadd v7.4s, v7.4s, v25.4s\n"
-    "and v31.16b, v17.16b, v12.16b\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "sqrdmulh v8.4s, v8.4s, v6.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "str d16, [x16, x1]\n"
-    "srshl v7.4s, v7.4s, v19.4s\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "and v24.16b, v8.16b, v19.16b\n"
-    "sqrdmulh v5.4s, v5.4s, v20.4s\n"
-    "sqadd v17.4s, v17.4s, v31.4s\n"
-    "add v7.4s, v7.4s, v10.4s\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "and v1.16b, v5.16b, v12.16b\n"
-    "smin v7.4s, v7.4s, v13.4s\n"
-    "srshl v17.4s, v17.4s, v12.4s\n"
-    "sqadd v8.4s, v8.4s, v24.4s\n"
-    "smax v7.4s, v7.4s, v11.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "add v17.4s, v17.4s, v10.4s\n"
-    "srshl v8.4s, v8.4s, v19.4s\n"
-    "sqadd v5.4s, v5.4s, v1.4s\n"
-    "smin v17.4s, v17.4s, v13.4s\n"
-    "add v8.4s, v8.4s, v10.4s\n"
-    "smax v17.4s, v17.4s, v11.4s\n"
-    "srshl v5.4s, v5.4s, v12.4s\n"
-    "smin v8.4s, v8.4s, v13.4s\n"
-    "uzp1 v7.16b, v7.16b, v17.16b\n"
-    "add v5.4s, v5.4s, v10.4s\n"
-    "uzp1 v7.16b, v7.16b, v7.16b\n"
-    "str d7, [x6, x1]\n"
-    "smax v8.4s, v8.4s, v11.4s\n"
-    "smin v5.4s, v5.4s, v13.4s\n"
-    "smax v5.4s, v5.4s, v11.4s\n"
-    "uzp1 v8.16b, v8.16b, v5.16b\n"
-    "uzp1 v8.16b, v8.16b, v8.16b\n"
-    "str d8, [x8, x1]\n"
-    "add x1, x1, #0x8\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
+    "str d15, [x16, x22]\n"
+    "uzp1 v10.16b, v10.16b, v10.16b\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "str d17, [x8, x22]\n"
+    "str d10, [x4, x22]\n"
+    "str d6, [x7, x22]\n"
+    "add x22, x22, #0x8\n"
     "beq 124f\n"
-    "add x3, x3, #0xc8\n"
+    "add x23, x23, #0xc8\n"
     "3:"  // Oddments
-    "ldr x12, [%x[params], %[offsetof_Params_bias]]\n"
-    "tbz x4, #2, 5f\n"
-    "ld1 { v15.4s }, [x12], #0x10\n"
-    "tbz x4, #1, 4f\n"
-    "ld1 { v18.d }[0], [x12], #0x8\n"
-    "tbz x4, #0, 7f\n"
-    "ld1 { v18.s }[2], [x12]\n"
+    "ldr x19, [%x[params], %[offsetof_Params_bias]]\n"
+    "tbz x0, #2, 5f\n"
+    "ld1 { v15.4s }, [x19], #0x10\n"
+    "tbz x0, #1, 4f\n"
+    "ld1 { v16.d }[0], [x19], #0x8\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v16.s }[2], [x19]\n"
     "b 7f\n"
     "4:"  // Oddments: Load bias: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
-    "ld1 { v18.s }[0], [x12]\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v16.s }[0], [x19]\n"
     "b 7f\n"
     "5:"  // Oddments: Load bias: Bit 2: Unset
-    "tbz x4, #1, 6f\n"
-    "ld1 { v15.d }[0], [x12], #0x8\n"
-    "tbz x4, #0, 7f\n"
-    "ld1 { v15.s }[2], [x12]\n"
+    "tbz x0, #1, 6f\n"
+    "ld1 { v15.d }[0], [x19], #0x8\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v15.s }[2], [x19]\n"
     "b 7f\n"
     "6:"  // Oddments: Load bias: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 7f\n"
-    "ld1 { v15.s }[0], [x12]\n"
+    "tbz x0, #0, 7f\n"
+    "ld1 { v15.s }[0], [x19]\n"
     "7:"  // Oddments: Load bias: Bit 2: End
-    "mov v16.16b, v15.16b\n"
-    "ldr d0, [x3, #0x0]\n"
-    "mov v21.16b, v18.16b\n"
-    "ldr d1, [x3, #0x8]\n"
-    "mov v7.16b, v15.16b\n"
-    "ldr d2, [x3, #0x10]\n"
-    "mov v17.16b, v18.16b\n"
-    "ldr d3, [x3, #0x18]\n"
-    "mov v8.16b, v15.16b\n"
-    "ldr d4, [x3, #0x20]\n"
-    "mov v5.16b, v18.16b\n"
-    "ldp x28, x27, [x25, #0x0]\n"
+    "ldr d0, [x23, #0x0]\n"
+    "ldr d1, [x23, #0x8]\n"
+    "mov v17.16b, v15.16b\n"
+    "mov v8.16b, v16.16b\n"
+    "ldr d2, [x23, #0x10]\n"
+    "ldr d3, [x23, #0x18]\n"
+    "mov v10.16b, v15.16b\n"
+    "mov v7.16b, v16.16b\n"
+    "ldr d4, [x23, #0x20]\n"
+    "ldp x28, x6, [x20, #0x0]\n"
+    "mov v6.16b, v15.16b\n"
+    "mov v5.16b, v16.16b\n"
+    "ldp x26, x25, [x20, #0x10]\n"
+    "ldp x5, x2, [x20, #0x20]\n"
     "ssubl v0.8h, v0.8b, v14.8b\n"
-    "ldp x26, x13, [x25, #0x10]\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldp x27, x21, [x20, #0x30]\n"
+    "ldp x12, x19, [x20, #0x40]\n"
     "ssubl v2.8h, v2.8b, v14.8b\n"
-    "ldp x24, x23, [x25, #0x20]\n"
     "ssubl v3.8h, v3.8b, v14.8b\n"
     "ssubl v4.8h, v4.8b, v14.8b\n"
-    "ldp x22, x21, [x25, #0x30]\n"
-    "ldp x20, x0, [x25, #0x40]\n"
-    "add x28, x28, x10\n"
-    "add x27, x27, x10\n"
-    "add x26, x26, x10\n"
-    "add x13, x13, x10\n"
-    "add x24, x24, x10\n"
-    "add x23, x23, x10\n"
-    "add x22, x22, x10\n"
-    "add x21, x21, x10\n"
-    "add x20, x20, x10\n"
-    "add x0, x0, x10\n"
-    "tbz x4, #2, 9f\n"
+    "add x28, x28, x24\n"
+    "add x6, x6, x24\n"
+    "add x26, x26, x24\n"
+    "add x25, x25, x24\n"
+    "add x5, x5, x24\n"
+    "add x2, x2, x24\n"
+    "add x27, x27, x24\n"
+    "add x21, x21, x24\n"
+    "add x12, x12, x24\n"
+    "add x19, x19, x24\n"
+    "tbz x0, #2, 9f\n"
     "ld1 { v31.s }[0], [x28], #0x4\n"
-    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "ld1 { v30.s }[0], [x6], #0x4\n"
     "ld1 { v29.s }[0], [x26], #0x4\n"
-    "ld1 { v28.s }[0], [x13], #0x4\n"
-    "ld1 { v27.s }[0], [x24], #0x4\n"
-    "ld1 { v23.s }[0], [x23], #0x4\n"
-    "ld1 { v25.s }[0], [x22], #0x4\n"
+    "ld1 { v28.s }[0], [x25], #0x4\n"
+    "ld1 { v27.s }[0], [x5], #0x4\n"
+    "ld1 { v23.s }[0], [x2], #0x4\n"
+    "ld1 { v25.s }[0], [x27], #0x4\n"
     "ld1 { v24.s }[0], [x21], #0x4\n"
-    "ld1 { v26.s }[0], [x20], #0x4\n"
-    "ld1 { v22.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 8f\n"
+    "ld1 { v26.s }[0], [x12], #0x4\n"
+    "ld1 { v22.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 8f\n"
     "ld1 { v31.h }[2], [x28], #0x2\n"
-    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "ld1 { v30.h }[2], [x6], #0x2\n"
     "ld1 { v29.h }[2], [x26], #0x2\n"
-    "ld1 { v28.h }[2], [x13], #0x2\n"
-    "ld1 { v27.h }[2], [x24], #0x2\n"
-    "ld1 { v23.h }[2], [x23], #0x2\n"
-    "ld1 { v25.h }[2], [x22], #0x2\n"
+    "ld1 { v28.h }[2], [x25], #0x2\n"
+    "ld1 { v27.h }[2], [x5], #0x2\n"
+    "ld1 { v23.h }[2], [x2], #0x2\n"
+    "ld1 { v25.h }[2], [x27], #0x2\n"
     "ld1 { v24.h }[2], [x21], #0x2\n"
-    "ld1 { v26.h }[2], [x20], #0x2\n"
-    "ld1 { v22.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "ld1 { v26.h }[2], [x12], #0x2\n"
+    "ld1 { v22.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[6], [x28]\n"
-    "ld1 { v30.b }[6], [x27]\n"
+    "ld1 { v30.b }[6], [x6]\n"
     "ld1 { v29.b }[6], [x26]\n"
-    "ld1 { v28.b }[6], [x13]\n"
-    "ld1 { v27.b }[6], [x24]\n"
-    "ld1 { v23.b }[6], [x23]\n"
-    "ld1 { v25.b }[6], [x22]\n"
+    "ld1 { v28.b }[6], [x25]\n"
+    "ld1 { v27.b }[6], [x5]\n"
+    "ld1 { v23.b }[6], [x2]\n"
+    "ld1 { v25.b }[6], [x27]\n"
     "ld1 { v24.b }[6], [x21]\n"
-    "ld1 { v26.b }[6], [x20]\n"
-    "ld1 { v22.b }[6], [x0]\n"
+    "ld1 { v26.b }[6], [x12]\n"
+    "ld1 { v22.b }[6], [x19]\n"
     "b 11f\n"
     "8:"  // Oddments: Initial loads: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[4], [x28]\n"
-    "ld1 { v30.b }[4], [x27]\n"
+    "ld1 { v30.b }[4], [x6]\n"
     "ld1 { v29.b }[4], [x26]\n"
-    "ld1 { v28.b }[4], [x13]\n"
-    "ld1 { v27.b }[4], [x24]\n"
-    "ld1 { v23.b }[4], [x23]\n"
-    "ld1 { v25.b }[4], [x22]\n"
+    "ld1 { v28.b }[4], [x25]\n"
+    "ld1 { v27.b }[4], [x5]\n"
+    "ld1 { v23.b }[4], [x2]\n"
+    "ld1 { v25.b }[4], [x27]\n"
     "ld1 { v24.b }[4], [x21]\n"
-    "ld1 { v26.b }[4], [x20]\n"
-    "ld1 { v22.b }[4], [x0]\n"
+    "ld1 { v26.b }[4], [x12]\n"
+    "ld1 { v22.b }[4], [x19]\n"
     "b 11f\n"
     "9:"  // Oddments: Initial loads: Bit 2: Unset
-    "tbz x4, #1, 10f\n"
+    "tbz x0, #1, 10f\n"
     "ld1 { v31.h }[0], [x28], #0x2\n"
-    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "ld1 { v30.h }[0], [x6], #0x2\n"
     "ld1 { v29.h }[0], [x26], #0x2\n"
-    "ld1 { v28.h }[0], [x13], #0x2\n"
-    "ld1 { v27.h }[0], [x24], #0x2\n"
-    "ld1 { v23.h }[0], [x23], #0x2\n"
-    "ld1 { v25.h }[0], [x22], #0x2\n"
+    "ld1 { v28.h }[0], [x25], #0x2\n"
+    "ld1 { v27.h }[0], [x5], #0x2\n"
+    "ld1 { v23.h }[0], [x2], #0x2\n"
+    "ld1 { v25.h }[0], [x27], #0x2\n"
     "ld1 { v24.h }[0], [x21], #0x2\n"
-    "ld1 { v26.h }[0], [x20], #0x2\n"
-    "ld1 { v22.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 11f\n"
+    "ld1 { v26.h }[0], [x12], #0x2\n"
+    "ld1 { v22.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[2], [x28]\n"
-    "ld1 { v30.b }[2], [x27]\n"
+    "ld1 { v30.b }[2], [x6]\n"
     "ld1 { v29.b }[2], [x26]\n"
-    "ld1 { v28.b }[2], [x13]\n"
-    "ld1 { v27.b }[2], [x24]\n"
-    "ld1 { v23.b }[2], [x23]\n"
-    "ld1 { v25.b }[2], [x22]\n"
+    "ld1 { v28.b }[2], [x25]\n"
+    "ld1 { v27.b }[2], [x5]\n"
+    "ld1 { v23.b }[2], [x2]\n"
+    "ld1 { v25.b }[2], [x27]\n"
     "ld1 { v24.b }[2], [x21]\n"
-    "ld1 { v26.b }[2], [x20]\n"
-    "ld1 { v22.b }[2], [x0]\n"
+    "ld1 { v26.b }[2], [x12]\n"
+    "ld1 { v22.b }[2], [x19]\n"
     "b 11f\n"
     "10:"  // Oddments: Initial loads: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 11f\n"
+    "tbz x0, #0, 11f\n"
     "ld1 { v31.b }[0], [x28]\n"
-    "ld1 { v30.b }[0], [x27]\n"
+    "ld1 { v30.b }[0], [x6]\n"
     "ld1 { v29.b }[0], [x26]\n"
-    "ld1 { v28.b }[0], [x13]\n"
-    "ld1 { v27.b }[0], [x24]\n"
-    "ld1 { v23.b }[0], [x23]\n"
-    "ld1 { v25.b }[0], [x22]\n"
+    "ld1 { v28.b }[0], [x25]\n"
+    "ld1 { v27.b }[0], [x5]\n"
+    "ld1 { v23.b }[0], [x2]\n"
+    "ld1 { v25.b }[0], [x27]\n"
     "ld1 { v24.b }[0], [x21]\n"
-    "ld1 { v26.b }[0], [x20]\n"
-    "ld1 { v22.b }[0], [x0]\n"
+    "ld1 { v26.b }[0], [x12]\n"
+    "ld1 { v22.b }[0], [x19]\n"
     "11:"  // Oddments: Initial loads: Bit 2: End
     "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr x20, [x25, #0x50]\n"
-    "add x20, x20, x10\n"
     "usubl v30.8h, v30.8b, v9.8b\n"
-    "usubl v29.8h, v29.8b, v9.8b\n"
-    "usubl v28.8h, v28.8b, v9.8b\n"
-    "usubl v27.8h, v27.8b, v9.8b\n"
-    "usubl v23.8h, v23.8b, v9.8b\n"
-    "usubl v25.8h, v25.8b, v9.8b\n"
-    "usubl v24.8h, v24.8b, v9.8b\n"
-    "usubl v26.8h, v26.8b, v9.8b\n"
-    "usubl v22.8h, v22.8b, v9.8b\n"
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "smlal v16.4s, v30.4h, v0.4h\n"
-    "smlal2 v21.4s, v30.8h, v0.8h\n"
-    "smlal v7.4s, v29.4h, v0.4h\n"
-    "smlal2 v17.4s, v29.8h, v0.8h\n"
-    "smlal v8.4s, v28.4h, v0.4h\n"
+    "ldr x19, [x20, #0x50]\n"
+    "usubl v29.8h, v29.8b, v9.8b\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "smlal v10.4s, v29.4h, v0.4h\n"
+    "usubl v28.8h, v28.8b, v9.8b\n"
+    "add x19, x19, x24\n"
+    "smlal2 v7.4s, v29.8h, v0.8h\n"
+    "usubl v27.8h, v27.8b, v9.8b\n"
+    "smlal v6.4s, v28.4h, v0.4h\n"
     "smlal2 v5.4s, v28.8h, v0.8h\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v18.4s, v30.8h, v1.8h\n"
-    "smlal v16.4s, v27.4h, v1.4h\n"
-    "smlal2 v21.4s, v27.8h, v1.8h\n"
-    "smlal v7.4s, v28.4h, v1.4h\n"
-    "smlal2 v17.4s, v28.8h, v1.8h\n"
-    "smlal v8.4s, v23.4h, v1.4h\n"
+    "usubl v23.8h, v23.8b, v9.8b\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "smlal v17.4s, v27.4h, v1.4h\n"
+    "usubl v25.8h, v25.8b, v9.8b\n"
+    "smlal2 v8.4s, v27.8h, v1.8h\n"
+    "smlal v10.4s, v28.4h, v1.4h\n"
+    "usubl v24.8h, v24.8b, v9.8b\n"
+    "smlal2 v7.4s, v28.8h, v1.8h\n"
+    "usubl v26.8h, v26.8b, v9.8b\n"
+    "smlal v6.4s, v23.4h, v1.4h\n"
+    "usubl v22.8h, v22.8b, v9.8b\n"
     "smlal2 v5.4s, v23.8h, v1.8h\n"
     "smlal v15.4s, v27.4h, v2.4h\n"
-    "smlal2 v18.4s, v27.8h, v2.8h\n"
-    "smlal v16.4s, v25.4h, v2.4h\n"
-    "smlal2 v21.4s, v25.8h, v2.8h\n"
-    "smlal v7.4s, v23.4h, v2.4h\n"
-    "smlal2 v17.4s, v23.8h, v2.8h\n"
-    "tbz x4, #2, 13f\n"
-    "ld1 { v31.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 12f\n"
-    "ld1 { v31.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[6], [x20]\n"
+    "smlal2 v16.4s, v27.8h, v2.8h\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal v10.4s, v23.4h, v2.4h\n"
+    "smlal2 v7.4s, v23.8h, v2.8h\n"
+    "tbz x0, #2, 13f\n"
+    "ld1 { v31.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 12f\n"
+    "ld1 { v31.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[6], [x19]\n"
     "b 15f\n"
     "12:"  // Oddments: Load (1, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[4], [x20]\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[4], [x19]\n"
     "b 15f\n"
     "13:"  // Oddments: Load (1, 3): Bit 2: Unset
-    "tbz x4, #1, 14f\n"
-    "ld1 { v31.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[2], [x20]\n"
+    "tbz x0, #1, 14f\n"
+    "ld1 { v31.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[2], [x19]\n"
     "b 15f\n"
     "14:"  // Oddments: Load (1, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 15f\n"
-    "ld1 { v31.b }[0], [x20]\n"
+    "tbz x0, #0, 15f\n"
+    "ld1 { v31.b }[0], [x19]\n"
     "15:"  // Oddments: Load (1, 3): Bit 2: End
     "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr x28, [x25, #0x58]\n"
-    "smlal v15.4s, v25.4h, v3.4h\n"
-    "add x28, x28, x10\n"
-    "smlal v8.4s, v31.4h, v2.4h\n"
+    "ldr x15, [x20, #0x58]\n"
+    "smlal v6.4s, v31.4h, v2.4h\n"
     "smlal2 v5.4s, v31.8h, v2.8h\n"
-    "smlal2 v18.4s, v25.8h, v3.8h\n"
-    "smlal v16.4s, v24.4h, v3.4h\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "smlal v7.4s, v31.4h, v3.4h\n"
-    "smlal2 v17.4s, v31.8h, v3.8h\n"
-    "tbz x4, #2, 17f\n"
-    "ld1 { v30.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 16f\n"
-    "ld1 { v30.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[6], [x28]\n"
+    "smlal v15.4s, v25.4h, v3.4h\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "add x15, x15, x24\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal v10.4s, v31.4h, v3.4h\n"
+    "smlal2 v7.4s, v31.8h, v3.8h\n"
+    "tbz x0, #2, 17f\n"
+    "ld1 { v30.s }[0], [x15], #0x4\n"
+    "tbz x0, #1, 16f\n"
+    "ld1 { v30.h }[2], [x15], #0x2\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[6], [x15]\n"
     "b 19f\n"
     "16:"  // Oddments: Load (1, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[4], [x28]\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[4], [x15]\n"
     "b 19f\n"
     "17:"  // Oddments: Load (1, 4): Bit 2: Unset
-    "tbz x4, #1, 18f\n"
-    "ld1 { v30.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[2], [x28]\n"
+    "tbz x0, #1, 18f\n"
+    "ld1 { v30.h }[0], [x15], #0x2\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[2], [x15]\n"
     "b 19f\n"
     "18:"  // Oddments: Load (1, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 19f\n"
-    "ld1 { v30.b }[0], [x28]\n"
+    "tbz x0, #0, 19f\n"
+    "ld1 { v30.b }[0], [x15]\n"
     "19:"  // Oddments: Load (1, 4): Bit 2: End
     "usubl v30.8h, v30.8b, v9.8b\n"
-    "ldr x0, [x25, #0x60]\n"
-    "smlal v15.4s, v24.4h, v4.4h\n"
-    "add x0, x0, x10\n"
-    "smlal v8.4s, v30.4h, v3.4h\n"
+    "ldr x19, [x20, #0x60]\n"
+    "smlal v6.4s, v30.4h, v3.4h\n"
     "smlal2 v5.4s, v30.8h, v3.8h\n"
-    "smlal2 v18.4s, v24.8h, v4.8h\n"
-    "tbz x4, #2, 21f\n"
-    "ld1 { v27.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 20f\n"
-    "ld1 { v27.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[6], [x0]\n"
+    "smlal v15.4s, v24.4h, v4.4h\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "add x19, x19, x24\n"
+    "tbz x0, #2, 21f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 20f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 23f\n"
     "20:"  // Oddments: Load (0, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[4], [x0]\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 23f\n"
     "21:"  // Oddments: Load (0, 5): Bit 2: Unset
-    "tbz x4, #1, 22f\n"
-    "ld1 { v27.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[2], [x0]\n"
+    "tbz x0, #1, 22f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 23f\n"
     "22:"  // Oddments: Load (0, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 23f\n"
-    "ld1 { v27.b }[0], [x0]\n"
+    "tbz x0, #0, 23f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "23:"  // Oddments: Load (0, 5): Bit 2: End
     "usubl v27.8h, v27.8b, v9.8b\n"
-    "ldr d0, [x3, #0x28]\n"
-    "smlal v7.4s, v30.4h, v4.4h\n"
-    "ldr x7, [x25, #0x68]\n"
-    "add x7, x7, x10\n"
-    "smlal v16.4s, v27.4h, v4.4h\n"
-    "smlal2 v21.4s, v27.8h, v4.8h\n"
-    "smlal2 v17.4s, v30.8h, v4.8h\n"
-    "smlal v8.4s, v26.4h, v4.4h\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "ldr d0, [x23, #0x28]\n"
+    "smlal v17.4s, v27.4h, v4.4h\n"
+    "smlal2 v8.4s, v27.8h, v4.8h\n"
+    "smlal v10.4s, v30.4h, v4.4h\n"
+    "smlal2 v7.4s, v30.8h, v4.8h\n"
     "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x27, [x20, #0x68]\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "add x27, x27, x24\n"
     "smlal v15.4s, v29.4h, v0.4h\n"
-    "smlal2 v18.4s, v29.8h, v0.8h\n"
-    "smlal v16.4s, v28.4h, v0.4h\n"
-    "smlal2 v21.4s, v28.8h, v0.8h\n"
-    "smlal v7.4s, v22.4h, v0.4h\n"
-    "smlal2 v17.4s, v22.8h, v0.8h\n"
-    "tbz x4, #2, 25f\n"
-    "ld1 { v25.s }[0], [x7], #0x4\n"
-    "tbz x4, #1, 24f\n"
-    "ld1 { v25.h }[2], [x7], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[6], [x7]\n"
+    "smlal2 v16.4s, v29.8h, v0.8h\n"
+    "smlal v17.4s, v28.4h, v0.4h\n"
+    "smlal2 v8.4s, v28.8h, v0.8h\n"
+    "smlal v10.4s, v22.4h, v0.4h\n"
+    "smlal2 v7.4s, v22.8h, v0.8h\n"
+    "tbz x0, #2, 25f\n"
+    "ld1 { v25.s }[0], [x27], #0x4\n"
+    "tbz x0, #1, 24f\n"
+    "ld1 { v25.h }[2], [x27], #0x2\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[6], [x27]\n"
     "b 27f\n"
     "24:"  // Oddments: Load (2, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[4], [x7]\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[4], [x27]\n"
     "b 27f\n"
     "25:"  // Oddments: Load (2, 1): Bit 2: Unset
-    "tbz x4, #1, 26f\n"
-    "ld1 { v25.h }[0], [x7], #0x2\n"
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[2], [x7]\n"
+    "tbz x0, #1, 26f\n"
+    "ld1 { v25.h }[0], [x27], #0x2\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[2], [x27]\n"
     "b 27f\n"
     "26:"  // Oddments: Load (2, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 27f\n"
-    "ld1 { v25.b }[0], [x7]\n"
+    "tbz x0, #0, 27f\n"
+    "ld1 { v25.b }[0], [x27]\n"
     "27:"  // Oddments: Load (2, 1): Bit 2: End
+    "ldr d1, [x23, #0x30]\n"
     "usubl v25.8h, v25.8b, v9.8b\n"
-    "ldr d1, [x3, #0x30]\n"
-    "smlal v8.4s, v25.4h, v0.4h\n"
-    "ldr x26, [x25, #0x70]\n"
-    "add x26, x26, x10\n"
-    "smlal2 v5.4s, v25.8h, v0.8h\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x5, [x20, #0x70]\n"
+    "smlal v6.4s, v25.4h, v0.4h\n"
+    "smlal2 v5.4s, v25.8h, v0.8h\n"
+    "add x5, x5, x24\n"
     "smlal v15.4s, v28.4h, v1.4h\n"
-    "smlal2 v18.4s, v28.8h, v1.8h\n"
-    "smlal v16.4s, v23.4h, v1.4h\n"
-    "smlal2 v21.4s, v23.8h, v1.8h\n"
-    "smlal v7.4s, v25.4h, v1.4h\n"
-    "smlal2 v17.4s, v25.8h, v1.8h\n"
-    "tbz x4, #2, 29f\n"
-    "ld1 { v24.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 28f\n"
-    "ld1 { v24.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[6], [x26]\n"
+    "smlal2 v16.4s, v28.8h, v1.8h\n"
+    "smlal v17.4s, v23.4h, v1.4h\n"
+    "smlal2 v8.4s, v23.8h, v1.8h\n"
+    "smlal v10.4s, v25.4h, v1.4h\n"
+    "smlal2 v7.4s, v25.8h, v1.8h\n"
+    "tbz x0, #2, 29f\n"
+    "ld1 { v24.s }[0], [x5], #0x4\n"
+    "tbz x0, #1, 28f\n"
+    "ld1 { v24.h }[2], [x5], #0x2\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[6], [x5]\n"
     "b 31f\n"
     "28:"  // Oddments: Load (2, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[4], [x26]\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[4], [x5]\n"
     "b 31f\n"
     "29:"  // Oddments: Load (2, 2): Bit 2: Unset
-    "tbz x4, #1, 30f\n"
-    "ld1 { v24.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[2], [x26]\n"
+    "tbz x0, #1, 30f\n"
+    "ld1 { v24.h }[0], [x5], #0x2\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[2], [x5]\n"
     "b 31f\n"
     "30:"  // Oddments: Load (2, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 31f\n"
-    "ld1 { v24.b }[0], [x26]\n"
+    "tbz x0, #0, 31f\n"
+    "ld1 { v24.b }[0], [x5]\n"
     "31:"  // Oddments: Load (2, 2): Bit 2: End
+    "ldr d2, [x23, #0x38]\n"
     "usubl v24.8h, v24.8b, v9.8b\n"
-    "ldr d2, [x3, #0x38]\n"
-    "smlal v8.4s, v24.4h, v1.4h\n"
-    "ldr x23, [x25, #0x78]\n"
-    "add x23, x23, x10\n"
-    "smlal2 v5.4s, v24.8h, v1.8h\n"
     "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x11, [x20, #0x78]\n"
+    "smlal v6.4s, v24.4h, v1.4h\n"
+    "smlal2 v5.4s, v24.8h, v1.8h\n"
+    "add x11, x11, x24\n"
     "smlal v15.4s, v23.4h, v2.4h\n"
-    "smlal2 v18.4s, v23.8h, v2.8h\n"
-    "smlal v16.4s, v31.4h, v2.4h\n"
-    "smlal2 v21.4s, v31.8h, v2.8h\n"
-    "smlal v7.4s, v24.4h, v2.4h\n"
-    "smlal2 v17.4s, v24.8h, v2.8h\n"
-    "tbz x4, #2, 33f\n"
-    "ld1 { v27.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 32f\n"
-    "ld1 { v27.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[6], [x23]\n"
+    "smlal2 v16.4s, v23.8h, v2.8h\n"
+    "smlal v17.4s, v31.4h, v2.4h\n"
+    "smlal2 v8.4s, v31.8h, v2.8h\n"
+    "smlal v10.4s, v24.4h, v2.4h\n"
+    "smlal2 v7.4s, v24.8h, v2.8h\n"
+    "tbz x0, #2, 33f\n"
+    "ld1 { v27.s }[0], [x11], #0x4\n"
+    "tbz x0, #1, 32f\n"
+    "ld1 { v27.h }[2], [x11], #0x2\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[6], [x11]\n"
     "b 35f\n"
     "32:"  // Oddments: Load (2, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[4], [x23]\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[4], [x11]\n"
     "b 35f\n"
     "33:"  // Oddments: Load (2, 3): Bit 2: Unset
-    "tbz x4, #1, 34f\n"
-    "ld1 { v27.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[2], [x23]\n"
+    "tbz x0, #1, 34f\n"
+    "ld1 { v27.h }[0], [x11], #0x2\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[2], [x11]\n"
     "b 35f\n"
     "34:"  // Oddments: Load (2, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 35f\n"
-    "ld1 { v27.b }[0], [x23]\n"
+    "tbz x0, #0, 35f\n"
+    "ld1 { v27.b }[0], [x11]\n"
     "35:"  // Oddments: Load (2, 3): Bit 2: End
+    "ldr d3, [x23, #0x40]\n"
     "usubl v27.8h, v27.8b, v9.8b\n"
-    "ldr d3, [x3, #0x40]\n"
-    "smlal v8.4s, v27.4h, v2.4h\n"
-    "ldr x20, [x25, #0x80]\n"
-    "add x20, x20, x10\n"
-    "smlal2 v5.4s, v27.8h, v2.8h\n"
     "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x12, [x20, #0x80]\n"
+    "smlal v6.4s, v27.4h, v2.4h\n"
+    "smlal2 v5.4s, v27.8h, v2.8h\n"
+    "add x12, x12, x24\n"
     "smlal v15.4s, v31.4h, v3.4h\n"
-    "smlal2 v18.4s, v31.8h, v3.8h\n"
-    "smlal v16.4s, v30.4h, v3.4h\n"
-    "smlal2 v21.4s, v30.8h, v3.8h\n"
-    "smlal v7.4s, v27.4h, v3.4h\n"
-    "smlal2 v17.4s, v27.8h, v3.8h\n"
-    "tbz x4, #2, 37f\n"
-    "ld1 { v23.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 36f\n"
-    "ld1 { v23.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[6], [x20]\n"
+    "smlal2 v16.4s, v31.8h, v3.8h\n"
+    "smlal v17.4s, v30.4h, v3.4h\n"
+    "smlal2 v8.4s, v30.8h, v3.8h\n"
+    "smlal v10.4s, v27.4h, v3.4h\n"
+    "smlal2 v7.4s, v27.8h, v3.8h\n"
+    "tbz x0, #2, 37f\n"
+    "ld1 { v23.s }[0], [x12], #0x4\n"
+    "tbz x0, #1, 36f\n"
+    "ld1 { v23.h }[2], [x12], #0x2\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[6], [x12]\n"
     "b 39f\n"
     "36:"  // Oddments: Load (2, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[4], [x20]\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[4], [x12]\n"
     "b 39f\n"
     "37:"  // Oddments: Load (2, 4): Bit 2: Unset
-    "tbz x4, #1, 38f\n"
-    "ld1 { v23.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[2], [x20]\n"
+    "tbz x0, #1, 38f\n"
+    "ld1 { v23.h }[0], [x12], #0x2\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[2], [x12]\n"
     "b 39f\n"
     "38:"  // Oddments: Load (2, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 39f\n"
-    "ld1 { v23.b }[0], [x20]\n"
+    "tbz x0, #0, 39f\n"
+    "ld1 { v23.b }[0], [x12]\n"
     "39:"  // Oddments: Load (2, 4): Bit 2: End
+    "ldr d4, [x23, #0x48]\n"
     "usubl v23.8h, v23.8b, v9.8b\n"
-    "ldr d4, [x3, #0x48]\n"
-    "smlal v8.4s, v23.4h, v3.4h\n"
-    "ldr x22, [x25, #0x88]\n"
-    "add x22, x22, x10\n"
-    "smlal2 v5.4s, v23.8h, v3.8h\n"
     "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x26, [x20, #0x88]\n"
+    "smlal v6.4s, v23.4h, v3.4h\n"
+    "smlal2 v5.4s, v23.8h, v3.8h\n"
+    "add x26, x26, x24\n"
     "smlal v15.4s, v30.4h, v4.4h\n"
-    "smlal2 v18.4s, v30.8h, v4.8h\n"
-    "smlal v16.4s, v26.4h, v4.4h\n"
-    "smlal2 v21.4s, v26.8h, v4.8h\n"
-    "smlal v7.4s, v23.4h, v4.4h\n"
-    "smlal2 v17.4s, v23.8h, v4.8h\n"
-    "tbz x4, #2, 41f\n"
-    "ld1 { v28.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 40f\n"
-    "ld1 { v28.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[6], [x22]\n"
+    "smlal2 v16.4s, v30.8h, v4.8h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "smlal v10.4s, v23.4h, v4.4h\n"
+    "smlal2 v7.4s, v23.8h, v4.8h\n"
+    "tbz x0, #2, 41f\n"
+    "ld1 { v28.s }[0], [x26], #0x4\n"
+    "tbz x0, #1, 40f\n"
+    "ld1 { v28.h }[2], [x26], #0x2\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[6], [x26]\n"
     "b 43f\n"
     "40:"  // Oddments: Load (2, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[4], [x22]\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[4], [x26]\n"
     "b 43f\n"
     "41:"  // Oddments: Load (2, 5): Bit 2: Unset
-    "tbz x4, #1, 42f\n"
-    "ld1 { v28.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[2], [x22]\n"
+    "tbz x0, #1, 42f\n"
+    "ld1 { v28.h }[0], [x26], #0x2\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[2], [x26]\n"
     "b 43f\n"
     "42:"  // Oddments: Load (2, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 43f\n"
-    "ld1 { v28.b }[0], [x22]\n"
+    "tbz x0, #0, 43f\n"
+    "ld1 { v28.b }[0], [x26]\n"
     "43:"  // Oddments: Load (2, 5): Bit 2: End
+    "ldr d0, [x23, #0x50]\n"
     "usubl v28.8h, v28.8b, v9.8b\n"
-    "ldr d0, [x3, #0x50]\n"
-    "smlal v8.4s, v28.4h, v4.4h\n"
-    "ldr x13, [x25, #0x90]\n"
-    "add x13, x13, x10\n"
-    "smlal2 v5.4s, v28.8h, v4.8h\n"
     "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x14, [x20, #0x90]\n"
+    "smlal v6.4s, v28.4h, v4.4h\n"
+    "smlal2 v5.4s, v28.8h, v4.8h\n"
+    "add x14, x14, x24\n"
     "smlal v15.4s, v22.4h, v0.4h\n"
-    "smlal2 v18.4s, v22.8h, v0.8h\n"
-    "smlal v16.4s, v25.4h, v0.4h\n"
-    "smlal2 v21.4s, v25.8h, v0.8h\n"
-    "tbz x4, #2, 45f\n"
-    "ld1 { v31.s }[0], [x13], #0x4\n"
-    "tbz x4, #1, 44f\n"
-    "ld1 { v31.h }[2], [x13], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[6], [x13]\n"
+    "smlal2 v16.4s, v22.8h, v0.8h\n"
+    "smlal v17.4s, v25.4h, v0.4h\n"
+    "smlal2 v8.4s, v25.8h, v0.8h\n"
+    "tbz x0, #2, 45f\n"
+    "ld1 { v31.s }[0], [x14], #0x4\n"
+    "tbz x0, #1, 44f\n"
+    "ld1 { v31.h }[2], [x14], #0x2\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[6], [x14]\n"
     "b 47f\n"
     "44:"  // Oddments: Load (3, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[4], [x13]\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[4], [x14]\n"
     "b 47f\n"
     "45:"  // Oddments: Load (3, 0): Bit 2: Unset
-    "tbz x4, #1, 46f\n"
-    "ld1 { v31.h }[0], [x13], #0x2\n"
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[2], [x13]\n"
+    "tbz x0, #1, 46f\n"
+    "ld1 { v31.h }[0], [x14], #0x2\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[2], [x14]\n"
     "b 47f\n"
     "46:"  // Oddments: Load (3, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 47f\n"
-    "ld1 { v31.b }[0], [x13]\n"
+    "tbz x0, #0, 47f\n"
+    "ld1 { v31.b }[0], [x14]\n"
     "47:"  // Oddments: Load (3, 0): Bit 2: End
     "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr x21, [x25, #0x98]\n"
-    "smlal v7.4s, v31.4h, v0.4h\n"
-    "add x21, x21, x10\n"
-    "smlal2 v17.4s, v31.8h, v0.8h\n"
-    "tbz x4, #2, 49f\n"
-    "ld1 { v30.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 48f\n"
-    "ld1 { v30.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[6], [x21]\n"
+    "ldr x15, [x20, #0x98]\n"
+    "smlal v10.4s, v31.4h, v0.4h\n"
+    "smlal2 v7.4s, v31.8h, v0.8h\n"
+    "add x15, x15, x24\n"
+    "tbz x0, #2, 49f\n"
+    "ld1 { v30.s }[0], [x15], #0x4\n"
+    "tbz x0, #1, 48f\n"
+    "ld1 { v30.h }[2], [x15], #0x2\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[6], [x15]\n"
     "b 51f\n"
     "48:"  // Oddments: Load (3, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[4], [x21]\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[4], [x15]\n"
     "b 51f\n"
     "49:"  // Oddments: Load (3, 1): Bit 2: Unset
-    "tbz x4, #1, 50f\n"
-    "ld1 { v30.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[2], [x21]\n"
+    "tbz x0, #1, 50f\n"
+    "ld1 { v30.h }[0], [x15], #0x2\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[2], [x15]\n"
     "b 51f\n"
     "50:"  // Oddments: Load (3, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 51f\n"
-    "ld1 { v30.b }[0], [x21]\n"
+    "tbz x0, #0, 51f\n"
+    "ld1 { v30.b }[0], [x15]\n"
     "51:"  // Oddments: Load (3, 1): Bit 2: End
+    "ldr d1, [x23, #0x58]\n"
     "usubl v30.8h, v30.8b, v9.8b\n"
-    "ldr d1, [x3, #0x58]\n"
-    "smlal v8.4s, v30.4h, v0.4h\n"
-    "ldr x14, [x25, #0xa0]\n"
-    "add x14, x14, x10\n"
-    "smlal2 v5.4s, v30.8h, v0.8h\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x21, [x20, #0xa0]\n"
+    "smlal v6.4s, v30.4h, v0.4h\n"
+    "smlal2 v5.4s, v30.8h, v0.8h\n"
+    "add x21, x21, x24\n"
     "smlal v15.4s, v25.4h, v1.4h\n"
-    "smlal2 v18.4s, v25.8h, v1.8h\n"
-    "smlal v16.4s, v24.4h, v1.4h\n"
-    "smlal2 v21.4s, v24.8h, v1.8h\n"
-    "smlal v7.4s, v30.4h, v1.4h\n"
-    "smlal2 v17.4s, v30.8h, v1.8h\n"
-    "tbz x4, #2, 53f\n"
-    "ld1 { v26.s }[0], [x14], #0x4\n"
-    "tbz x4, #1, 52f\n"
-    "ld1 { v26.h }[2], [x14], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[6], [x14]\n"
+    "smlal2 v16.4s, v25.8h, v1.8h\n"
+    "smlal v17.4s, v24.4h, v1.4h\n"
+    "smlal2 v8.4s, v24.8h, v1.8h\n"
+    "smlal v10.4s, v30.4h, v1.4h\n"
+    "smlal2 v7.4s, v30.8h, v1.8h\n"
+    "tbz x0, #2, 53f\n"
+    "ld1 { v26.s }[0], [x21], #0x4\n"
+    "tbz x0, #1, 52f\n"
+    "ld1 { v26.h }[2], [x21], #0x2\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[6], [x21]\n"
     "b 55f\n"
     "52:"  // Oddments: Load (3, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[4], [x14]\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[4], [x21]\n"
     "b 55f\n"
     "53:"  // Oddments: Load (3, 2): Bit 2: Unset
-    "tbz x4, #1, 54f\n"
-    "ld1 { v26.h }[0], [x14], #0x2\n"
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[2], [x14]\n"
+    "tbz x0, #1, 54f\n"
+    "ld1 { v26.h }[0], [x21], #0x2\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[2], [x21]\n"
     "b 55f\n"
     "54:"  // Oddments: Load (3, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 55f\n"
-    "ld1 { v26.b }[0], [x14]\n"
+    "tbz x0, #0, 55f\n"
+    "ld1 { v26.b }[0], [x21]\n"
     "55:"  // Oddments: Load (3, 2): Bit 2: End
+    "ldr d2, [x23, #0x60]\n"
     "usubl v26.8h, v26.8b, v9.8b\n"
-    "ldr d2, [x3, #0x60]\n"
-    "smlal v8.4s, v26.4h, v1.4h\n"
-    "ldr x11, [x25, #0xa8]\n"
-    "add x11, x11, x10\n"
-    "smlal2 v5.4s, v26.8h, v1.8h\n"
     "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x2, [x20, #0xa8]\n"
+    "smlal v6.4s, v26.4h, v1.4h\n"
+    "smlal2 v5.4s, v26.8h, v1.8h\n"
+    "add x2, x2, x24\n"
     "smlal v15.4s, v24.4h, v2.4h\n"
-    "smlal2 v18.4s, v24.8h, v2.8h\n"
-    "smlal v16.4s, v27.4h, v2.4h\n"
-    "smlal2 v21.4s, v27.8h, v2.8h\n"
-    "smlal v7.4s, v26.4h, v2.4h\n"
-    "smlal2 v17.4s, v26.8h, v2.8h\n"
-    "tbz x4, #2, 57f\n"
-    "ld1 { v25.s }[0], [x11], #0x4\n"
-    "tbz x4, #1, 56f\n"
-    "ld1 { v25.h }[2], [x11], #0x2\n"
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[6], [x11]\n"
+    "smlal2 v16.4s, v24.8h, v2.8h\n"
+    "smlal v17.4s, v27.4h, v2.4h\n"
+    "smlal2 v8.4s, v27.8h, v2.8h\n"
+    "smlal v10.4s, v26.4h, v2.4h\n"
+    "smlal2 v7.4s, v26.8h, v2.8h\n"
+    "tbz x0, #2, 57f\n"
+    "ld1 { v25.s }[0], [x2], #0x4\n"
+    "tbz x0, #1, 56f\n"
+    "ld1 { v25.h }[2], [x2], #0x2\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[6], [x2]\n"
     "b 59f\n"
     "56:"  // Oddments: Load (3, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[4], [x11]\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[4], [x2]\n"
     "b 59f\n"
     "57:"  // Oddments: Load (3, 3): Bit 2: Unset
-    "tbz x4, #1, 58f\n"
-    "ld1 { v25.h }[0], [x11], #0x2\n"
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[2], [x11]\n"
+    "tbz x0, #1, 58f\n"
+    "ld1 { v25.h }[0], [x2], #0x2\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[2], [x2]\n"
     "b 59f\n"
     "58:"  // Oddments: Load (3, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 59f\n"
-    "ld1 { v25.b }[0], [x11]\n"
+    "tbz x0, #0, 59f\n"
+    "ld1 { v25.b }[0], [x2]\n"
     "59:"  // Oddments: Load (3, 3): Bit 2: End
+    "ldr d3, [x23, #0x68]\n"
     "usubl v25.8h, v25.8b, v9.8b\n"
-    "ldr d3, [x3, #0x68]\n"
-    "smlal v8.4s, v25.4h, v2.4h\n"
-    "ldr x24, [x25, #0xb0]\n"
-    "add x24, x24, x10\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
     "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x13, [x20, #0xb0]\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x13, x13, x24\n"
     "smlal v15.4s, v27.4h, v3.4h\n"
-    "smlal2 v18.4s, v27.8h, v3.8h\n"
-    "smlal v16.4s, v23.4h, v3.4h\n"
-    "smlal2 v21.4s, v23.8h, v3.8h\n"
-    "smlal v7.4s, v25.4h, v3.4h\n"
-    "smlal2 v17.4s, v25.8h, v3.8h\n"
-    "tbz x4, #2, 61f\n"
-    "ld1 { v24.s }[0], [x24], #0x4\n"
-    "tbz x4, #1, 60f\n"
-    "ld1 { v24.h }[2], [x24], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[6], [x24]\n"
+    "smlal2 v16.4s, v27.8h, v3.8h\n"
+    "smlal v17.4s, v23.4h, v3.4h\n"
+    "smlal2 v8.4s, v23.8h, v3.8h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "tbz x0, #2, 61f\n"
+    "ld1 { v24.s }[0], [x13], #0x4\n"
+    "tbz x0, #1, 60f\n"
+    "ld1 { v24.h }[2], [x13], #0x2\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[6], [x13]\n"
     "b 63f\n"
     "60:"  // Oddments: Load (3, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[4], [x24]\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[4], [x13]\n"
     "b 63f\n"
     "61:"  // Oddments: Load (3, 4): Bit 2: Unset
-    "tbz x4, #1, 62f\n"
-    "ld1 { v24.h }[0], [x24], #0x2\n"
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[2], [x24]\n"
+    "tbz x0, #1, 62f\n"
+    "ld1 { v24.h }[0], [x13], #0x2\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[2], [x13]\n"
     "b 63f\n"
     "62:"  // Oddments: Load (3, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 63f\n"
-    "ld1 { v24.b }[0], [x24]\n"
+    "tbz x0, #0, 63f\n"
+    "ld1 { v24.b }[0], [x13]\n"
     "63:"  // Oddments: Load (3, 4): Bit 2: End
+    "ldr d4, [x23, #0x70]\n"
     "usubl v24.8h, v24.8b, v9.8b\n"
-    "ldr d4, [x3, #0x70]\n"
-    "smlal v8.4s, v24.4h, v3.4h\n"
-    "ldr x0, [x25, #0xb8]\n"
-    "add x0, x0, x10\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
     "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x9, [x20, #0xb8]\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "add x9, x9, x24\n"
     "smlal v15.4s, v23.4h, v4.4h\n"
-    "smlal2 v18.4s, v23.8h, v4.8h\n"
-    "smlal v16.4s, v28.4h, v4.4h\n"
-    "smlal2 v21.4s, v28.8h, v4.8h\n"
-    "smlal v7.4s, v24.4h, v4.4h\n"
-    "smlal2 v17.4s, v24.8h, v4.8h\n"
-    "tbz x4, #2, 65f\n"
-    "ld1 { v22.s }[0], [x0], #0x4\n"
-    "tbz x4, #1, 64f\n"
-    "ld1 { v22.h }[2], [x0], #0x2\n"
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[6], [x0]\n"
+    "smlal2 v16.4s, v23.8h, v4.8h\n"
+    "smlal v17.4s, v28.4h, v4.4h\n"
+    "smlal2 v8.4s, v28.8h, v4.8h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "tbz x0, #2, 65f\n"
+    "ld1 { v22.s }[0], [x9], #0x4\n"
+    "tbz x0, #1, 64f\n"
+    "ld1 { v22.h }[2], [x9], #0x2\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[6], [x9]\n"
     "b 67f\n"
     "64:"  // Oddments: Load (3, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[4], [x0]\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[4], [x9]\n"
     "b 67f\n"
     "65:"  // Oddments: Load (3, 5): Bit 2: Unset
-    "tbz x4, #1, 66f\n"
-    "ld1 { v22.h }[0], [x0], #0x2\n"
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[2], [x0]\n"
+    "tbz x0, #1, 66f\n"
+    "ld1 { v22.h }[0], [x9], #0x2\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[2], [x9]\n"
     "b 67f\n"
     "66:"  // Oddments: Load (3, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 67f\n"
-    "ld1 { v22.b }[0], [x0]\n"
+    "tbz x0, #0, 67f\n"
+    "ld1 { v22.b }[0], [x9]\n"
     "67:"  // Oddments: Load (3, 5): Bit 2: End
+    "ldr d0, [x23, #0x78]\n"
     "usubl v22.8h, v22.8b, v9.8b\n"
-    "ldr d0, [x3, #0x78]\n"
-    "smlal v8.4s, v22.4h, v4.4h\n"
-    "ldr x15, [x25, #0xc0]\n"
-    "add x15, x15, x10\n"
-    "smlal2 v5.4s, v22.8h, v4.8h\n"
     "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x19, [x20, #0xc0]\n"
+    "smlal v6.4s, v22.4h, v4.4h\n"
+    "smlal2 v5.4s, v22.8h, v4.8h\n"
+    "add x19, x19, x24\n"
     "smlal v15.4s, v31.4h, v0.4h\n"
-    "smlal2 v18.4s, v31.8h, v0.8h\n"
-    "smlal v16.4s, v30.4h, v0.4h\n"
-    "smlal2 v21.4s, v30.8h, v0.8h\n"
-    "tbz x4, #2, 69f\n"
-    "ld1 { v27.s }[0], [x15], #0x4\n"
-    "tbz x4, #1, 68f\n"
-    "ld1 { v27.h }[2], [x15], #0x2\n"
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[6], [x15]\n"
+    "smlal2 v16.4s, v31.8h, v0.8h\n"
+    "smlal v17.4s, v30.4h, v0.4h\n"
+    "smlal2 v8.4s, v30.8h, v0.8h\n"
+    "tbz x0, #2, 69f\n"
+    "ld1 { v27.s }[0], [x19], #0x4\n"
+    "tbz x0, #1, 68f\n"
+    "ld1 { v27.h }[2], [x19], #0x2\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[6], [x19]\n"
     "b 71f\n"
     "68:"  // Oddments: Load (4, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[4], [x15]\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[4], [x19]\n"
     "b 71f\n"
     "69:"  // Oddments: Load (4, 0): Bit 2: Unset
-    "tbz x4, #1, 70f\n"
-    "ld1 { v27.h }[0], [x15], #0x2\n"
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[2], [x15]\n"
+    "tbz x0, #1, 70f\n"
+    "ld1 { v27.h }[0], [x19], #0x2\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[2], [x19]\n"
     "b 71f\n"
     "70:"  // Oddments: Load (4, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 71f\n"
-    "ld1 { v27.b }[0], [x15]\n"
+    "tbz x0, #0, 71f\n"
+    "ld1 { v27.b }[0], [x19]\n"
     "71:"  // Oddments: Load (4, 0): Bit 2: End
     "usubl v27.8h, v27.8b, v9.8b\n"
-    "ldr x9, [x25, #0xc8]\n"
-    "smlal v7.4s, v27.4h, v0.4h\n"
-    "add x9, x9, x10\n"
-    "smlal2 v17.4s, v27.8h, v0.8h\n"
-    "tbz x4, #2, 73f\n"
-    "ld1 { v23.s }[0], [x9], #0x4\n"
-    "tbz x4, #1, 72f\n"
-    "ld1 { v23.h }[2], [x9], #0x2\n"
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[6], [x9]\n"
+    "ldr x28, [x20, #0xc8]\n"
+    "smlal v10.4s, v27.4h, v0.4h\n"
+    "smlal2 v7.4s, v27.8h, v0.8h\n"
+    "add x28, x28, x24\n"
+    "tbz x0, #2, 73f\n"
+    "ld1 { v23.s }[0], [x28], #0x4\n"
+    "tbz x0, #1, 72f\n"
+    "ld1 { v23.h }[2], [x28], #0x2\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[6], [x28]\n"
     "b 75f\n"
     "72:"  // Oddments: Load (4, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[4], [x9]\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[4], [x28]\n"
     "b 75f\n"
     "73:"  // Oddments: Load (4, 1): Bit 2: Unset
-    "tbz x4, #1, 74f\n"
-    "ld1 { v23.h }[0], [x9], #0x2\n"
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[2], [x9]\n"
+    "tbz x0, #1, 74f\n"
+    "ld1 { v23.h }[0], [x28], #0x2\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[2], [x28]\n"
     "b 75f\n"
     "74:"  // Oddments: Load (4, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 75f\n"
-    "ld1 { v23.b }[0], [x9]\n"
+    "tbz x0, #0, 75f\n"
+    "ld1 { v23.b }[0], [x28]\n"
     "75:"  // Oddments: Load (4, 1): Bit 2: End
+    "ldr d1, [x23, #0x80]\n"
     "usubl v23.8h, v23.8b, v9.8b\n"
-    "ldr d1, [x3, #0x80]\n"
-    "smlal v8.4s, v23.4h, v0.4h\n"
-    "ldr x27, [x25, #0xd0]\n"
-    "add x27, x27, x10\n"
-    "smlal2 v5.4s, v23.8h, v0.8h\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x6, [x20, #0xd0]\n"
+    "smlal v6.4s, v23.4h, v0.4h\n"
+    "smlal2 v5.4s, v23.8h, v0.8h\n"
+    "add x6, x6, x24\n"
     "smlal v15.4s, v30.4h, v1.4h\n"
-    "smlal2 v18.4s, v30.8h, v1.8h\n"
-    "smlal v16.4s, v26.4h, v1.4h\n"
-    "smlal2 v21.4s, v26.8h, v1.8h\n"
-    "smlal v7.4s, v23.4h, v1.4h\n"
-    "smlal2 v17.4s, v23.8h, v1.8h\n"
-    "tbz x4, #2, 77f\n"
-    "ld1 { v31.s }[0], [x27], #0x4\n"
-    "tbz x4, #1, 76f\n"
-    "ld1 { v31.h }[2], [x27], #0x2\n"
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[6], [x27]\n"
+    "smlal2 v16.4s, v30.8h, v1.8h\n"
+    "smlal v17.4s, v26.4h, v1.4h\n"
+    "smlal2 v8.4s, v26.8h, v1.8h\n"
+    "smlal v10.4s, v23.4h, v1.4h\n"
+    "smlal2 v7.4s, v23.8h, v1.8h\n"
+    "tbz x0, #2, 77f\n"
+    "ld1 { v31.s }[0], [x6], #0x4\n"
+    "tbz x0, #1, 76f\n"
+    "ld1 { v31.h }[2], [x6], #0x2\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[6], [x6]\n"
     "b 79f\n"
     "76:"  // Oddments: Load (4, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[4], [x27]\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[4], [x6]\n"
     "b 79f\n"
     "77:"  // Oddments: Load (4, 2): Bit 2: Unset
-    "tbz x4, #1, 78f\n"
-    "ld1 { v31.h }[0], [x27], #0x2\n"
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[2], [x27]\n"
+    "tbz x0, #1, 78f\n"
+    "ld1 { v31.h }[0], [x6], #0x2\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[2], [x6]\n"
     "b 79f\n"
     "78:"  // Oddments: Load (4, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 79f\n"
-    "ld1 { v31.b }[0], [x27]\n"
+    "tbz x0, #0, 79f\n"
+    "ld1 { v31.b }[0], [x6]\n"
     "79:"  // Oddments: Load (4, 2): Bit 2: End
+    "ldr d2, [x23, #0x88]\n"
     "usubl v31.8h, v31.8b, v9.8b\n"
-    "ldr d2, [x3, #0x88]\n"
-    "smlal v8.4s, v31.4h, v1.4h\n"
-    "ldr x28, [x25, #0xd8]\n"
-    "add x28, x28, x10\n"
-    "smlal2 v5.4s, v31.8h, v1.8h\n"
     "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x27, [x20, #0xd8]\n"
+    "smlal v6.4s, v31.4h, v1.4h\n"
+    "smlal2 v5.4s, v31.8h, v1.8h\n"
+    "add x27, x27, x24\n"
     "smlal v15.4s, v26.4h, v2.4h\n"
-    "smlal2 v18.4s, v26.8h, v2.8h\n"
-    "smlal v16.4s, v25.4h, v2.4h\n"
-    "smlal2 v21.4s, v25.8h, v2.8h\n"
-    "smlal v7.4s, v31.4h, v2.4h\n"
-    "smlal2 v17.4s, v31.8h, v2.8h\n"
-    "tbz x4, #2, 81f\n"
-    "ld1 { v30.s }[0], [x28], #0x4\n"
-    "tbz x4, #1, 80f\n"
-    "ld1 { v30.h }[2], [x28], #0x2\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[6], [x28]\n"
+    "smlal2 v16.4s, v26.8h, v2.8h\n"
+    "smlal v17.4s, v25.4h, v2.4h\n"
+    "smlal2 v8.4s, v25.8h, v2.8h\n"
+    "smlal v10.4s, v31.4h, v2.4h\n"
+    "smlal2 v7.4s, v31.8h, v2.8h\n"
+    "tbz x0, #2, 81f\n"
+    "ld1 { v30.s }[0], [x27], #0x4\n"
+    "tbz x0, #1, 80f\n"
+    "ld1 { v30.h }[2], [x27], #0x2\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[6], [x27]\n"
     "b 83f\n"
     "80:"  // Oddments: Load (4, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[4], [x28]\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[4], [x27]\n"
     "b 83f\n"
     "81:"  // Oddments: Load (4, 3): Bit 2: Unset
-    "tbz x4, #1, 82f\n"
-    "ld1 { v30.h }[0], [x28], #0x2\n"
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[2], [x28]\n"
+    "tbz x0, #1, 82f\n"
+    "ld1 { v30.h }[0], [x27], #0x2\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[2], [x27]\n"
     "b 83f\n"
     "82:"  // Oddments: Load (4, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 83f\n"
-    "ld1 { v30.b }[0], [x28]\n"
+    "tbz x0, #0, 83f\n"
+    "ld1 { v30.b }[0], [x27]\n"
     "83:"  // Oddments: Load (4, 3): Bit 2: End
+    "ldr d3, [x23, #0x90]\n"
     "usubl v30.8h, v30.8b, v9.8b\n"
-    "ldr d3, [x3, #0x90]\n"
-    "smlal v8.4s, v30.4h, v2.4h\n"
-    "ldr x12, [x25, #0xe0]\n"
-    "add x12, x12, x10\n"
-    "smlal2 v5.4s, v30.8h, v2.8h\n"
     "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x11, [x20, #0xe0]\n"
+    "smlal v6.4s, v30.4h, v2.4h\n"
+    "smlal2 v5.4s, v30.8h, v2.8h\n"
+    "add x11, x11, x24\n"
     "smlal v15.4s, v25.4h, v3.4h\n"
-    "smlal2 v18.4s, v25.8h, v3.8h\n"
-    "smlal v16.4s, v24.4h, v3.4h\n"
-    "smlal2 v21.4s, v24.8h, v3.8h\n"
-    "smlal v7.4s, v30.4h, v3.4h\n"
-    "smlal2 v17.4s, v30.8h, v3.8h\n"
-    "tbz x4, #2, 85f\n"
-    "ld1 { v28.s }[0], [x12], #0x4\n"
-    "tbz x4, #1, 84f\n"
-    "ld1 { v28.h }[2], [x12], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[6], [x12]\n"
+    "smlal2 v16.4s, v25.8h, v3.8h\n"
+    "smlal v17.4s, v24.4h, v3.4h\n"
+    "smlal2 v8.4s, v24.8h, v3.8h\n"
+    "smlal v10.4s, v30.4h, v3.4h\n"
+    "smlal2 v7.4s, v30.8h, v3.8h\n"
+    "tbz x0, #2, 85f\n"
+    "ld1 { v28.s }[0], [x11], #0x4\n"
+    "tbz x0, #1, 84f\n"
+    "ld1 { v28.h }[2], [x11], #0x2\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[6], [x11]\n"
     "b 87f\n"
     "84:"  // Oddments: Load (4, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[4], [x12]\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[4], [x11]\n"
     "b 87f\n"
     "85:"  // Oddments: Load (4, 4): Bit 2: Unset
-    "tbz x4, #1, 86f\n"
-    "ld1 { v28.h }[0], [x12], #0x2\n"
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[2], [x12]\n"
+    "tbz x0, #1, 86f\n"
+    "ld1 { v28.h }[0], [x11], #0x2\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[2], [x11]\n"
     "b 87f\n"
     "86:"  // Oddments: Load (4, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 87f\n"
-    "ld1 { v28.b }[0], [x12]\n"
+    "tbz x0, #0, 87f\n"
+    "ld1 { v28.b }[0], [x11]\n"
     "87:"  // Oddments: Load (4, 4): Bit 2: End
+    "ldr d4, [x23, #0x98]\n"
     "usubl v28.8h, v28.8b, v9.8b\n"
-    "ldr d4, [x3, #0x98]\n"
-    "smlal v8.4s, v28.4h, v3.4h\n"
-    "ldr x7, [x25, #0xe8]\n"
-    "add x7, x7, x10\n"
-    "smlal2 v5.4s, v28.8h, v3.8h\n"
     "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x17, [x20, #0xe8]\n"
+    "smlal v6.4s, v28.4h, v3.4h\n"
+    "smlal2 v5.4s, v28.8h, v3.8h\n"
+    "add x17, x17, x24\n"
     "smlal v15.4s, v24.4h, v4.4h\n"
-    "smlal2 v18.4s, v24.8h, v4.8h\n"
-    "smlal v16.4s, v22.4h, v4.4h\n"
-    "smlal2 v21.4s, v22.8h, v4.8h\n"
-    "smlal v7.4s, v28.4h, v4.4h\n"
-    "smlal2 v17.4s, v28.8h, v4.8h\n"
-    "tbz x4, #2, 89f\n"
-    "ld1 { v26.s }[0], [x7], #0x4\n"
-    "tbz x4, #1, 88f\n"
-    "ld1 { v26.h }[2], [x7], #0x2\n"
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[6], [x7]\n"
+    "smlal2 v16.4s, v24.8h, v4.8h\n"
+    "smlal v17.4s, v22.4h, v4.4h\n"
+    "smlal2 v8.4s, v22.8h, v4.8h\n"
+    "smlal v10.4s, v28.4h, v4.4h\n"
+    "smlal2 v7.4s, v28.8h, v4.8h\n"
+    "tbz x0, #2, 89f\n"
+    "ld1 { v26.s }[0], [x17], #0x4\n"
+    "tbz x0, #1, 88f\n"
+    "ld1 { v26.h }[2], [x17], #0x2\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[6], [x17]\n"
     "b 91f\n"
     "88:"  // Oddments: Load (4, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[4], [x7]\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[4], [x17]\n"
     "b 91f\n"
     "89:"  // Oddments: Load (4, 5): Bit 2: Unset
-    "tbz x4, #1, 90f\n"
-    "ld1 { v26.h }[0], [x7], #0x2\n"
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[2], [x7]\n"
+    "tbz x0, #1, 90f\n"
+    "ld1 { v26.h }[0], [x17], #0x2\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[2], [x17]\n"
     "b 91f\n"
     "90:"  // Oddments: Load (4, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 91f\n"
-    "ld1 { v26.b }[0], [x7]\n"
+    "tbz x0, #0, 91f\n"
+    "ld1 { v26.b }[0], [x17]\n"
     "91:"  // Oddments: Load (4, 5): Bit 2: End
+    "ldr d0, [x23, #0xa0]\n"
     "usubl v26.8h, v26.8b, v9.8b\n"
-    "ldr d0, [x3, #0xa0]\n"
-    "smlal v8.4s, v26.4h, v4.4h\n"
-    "ldr x26, [x25, #0xf0]\n"
-    "add x26, x26, x10\n"
-    "smlal2 v5.4s, v26.8h, v4.8h\n"
     "ssubl v0.8h, v0.8b, v14.8b\n"
+    "ldr x5, [x20, #0xf0]\n"
+    "smlal v6.4s, v26.4h, v4.4h\n"
+    "smlal2 v5.4s, v26.8h, v4.8h\n"
+    "add x5, x5, x24\n"
     "smlal v15.4s, v27.4h, v0.4h\n"
-    "smlal2 v18.4s, v27.8h, v0.8h\n"
-    "smlal v16.4s, v23.4h, v0.4h\n"
-    "smlal2 v21.4s, v23.8h, v0.8h\n"
-    "tbz x4, #2, 93f\n"
-    "ld1 { v25.s }[0], [x26], #0x4\n"
-    "tbz x4, #1, 92f\n"
-    "ld1 { v25.h }[2], [x26], #0x2\n"
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[6], [x26]\n"
+    "smlal2 v16.4s, v27.8h, v0.8h\n"
+    "smlal v17.4s, v23.4h, v0.4h\n"
+    "smlal2 v8.4s, v23.8h, v0.8h\n"
+    "tbz x0, #2, 93f\n"
+    "ld1 { v25.s }[0], [x5], #0x4\n"
+    "tbz x0, #1, 92f\n"
+    "ld1 { v25.h }[2], [x5], #0x2\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[6], [x5]\n"
     "b 95f\n"
     "92:"  // Oddments: Load (5, 0): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[4], [x26]\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[4], [x5]\n"
     "b 95f\n"
     "93:"  // Oddments: Load (5, 0): Bit 2: Unset
-    "tbz x4, #1, 94f\n"
-    "ld1 { v25.h }[0], [x26], #0x2\n"
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[2], [x26]\n"
+    "tbz x0, #1, 94f\n"
+    "ld1 { v25.h }[0], [x5], #0x2\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[2], [x5]\n"
     "b 95f\n"
     "94:"  // Oddments: Load (5, 0): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 95f\n"
-    "ld1 { v25.b }[0], [x26]\n"
+    "tbz x0, #0, 95f\n"
+    "ld1 { v25.b }[0], [x5]\n"
     "95:"  // Oddments: Load (5, 0): Bit 2: End
     "usubl v25.8h, v25.8b, v9.8b\n"
-    "ldr x23, [x25, #0xf8]\n"
-    "smlal v7.4s, v25.4h, v0.4h\n"
-    "add x23, x23, x10\n"
-    "smlal2 v17.4s, v25.8h, v0.8h\n"
-    "tbz x4, #2, 97f\n"
-    "ld1 { v24.s }[0], [x23], #0x4\n"
-    "tbz x4, #1, 96f\n"
-    "ld1 { v24.h }[2], [x23], #0x2\n"
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[6], [x23]\n"
+    "ldr x25, [x20, #0xf8]\n"
+    "smlal v10.4s, v25.4h, v0.4h\n"
+    "smlal2 v7.4s, v25.8h, v0.8h\n"
+    "add x25, x25, x24\n"
+    "tbz x0, #2, 97f\n"
+    "ld1 { v24.s }[0], [x25], #0x4\n"
+    "tbz x0, #1, 96f\n"
+    "ld1 { v24.h }[2], [x25], #0x2\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[6], [x25]\n"
     "b 99f\n"
     "96:"  // Oddments: Load (5, 1): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[4], [x23]\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[4], [x25]\n"
     "b 99f\n"
     "97:"  // Oddments: Load (5, 1): Bit 2: Unset
-    "tbz x4, #1, 98f\n"
-    "ld1 { v24.h }[0], [x23], #0x2\n"
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[2], [x23]\n"
+    "tbz x0, #1, 98f\n"
+    "ld1 { v24.h }[0], [x25], #0x2\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[2], [x25]\n"
     "b 99f\n"
     "98:"  // Oddments: Load (5, 1): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 99f\n"
-    "ld1 { v24.b }[0], [x23]\n"
+    "tbz x0, #0, 99f\n"
+    "ld1 { v24.b }[0], [x25]\n"
     "99:"  // Oddments: Load (5, 1): Bit 2: End
+    "ldr d1, [x23, #0xa8]\n"
     "usubl v24.8h, v24.8b, v9.8b\n"
-    "ldr d1, [x3, #0xa8]\n"
-    "smlal v8.4s, v24.4h, v0.4h\n"
-    "ldr x22, [x25, #0x100]\n"
-    "add x22, x22, x10\n"
-    "smlal2 v5.4s, v24.8h, v0.8h\n"
     "ssubl v1.8h, v1.8b, v14.8b\n"
+    "ldr x26, [x20, #0x100]\n"
+    "smlal v6.4s, v24.4h, v0.4h\n"
+    "smlal2 v5.4s, v24.8h, v0.8h\n"
+    "add x26, x26, x24\n"
     "smlal v15.4s, v23.4h, v1.4h\n"
-    "smlal2 v18.4s, v23.8h, v1.8h\n"
-    "smlal v16.4s, v31.4h, v1.4h\n"
-    "smlal2 v21.4s, v31.8h, v1.8h\n"
-    "smlal v7.4s, v24.4h, v1.4h\n"
-    "smlal2 v17.4s, v24.8h, v1.8h\n"
-    "tbz x4, #2, 101f\n"
-    "ld1 { v27.s }[0], [x22], #0x4\n"
-    "tbz x4, #1, 100f\n"
-    "ld1 { v27.h }[2], [x22], #0x2\n"
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[6], [x22]\n"
+    "smlal2 v16.4s, v23.8h, v1.8h\n"
+    "smlal v17.4s, v31.4h, v1.4h\n"
+    "smlal2 v8.4s, v31.8h, v1.8h\n"
+    "smlal v10.4s, v24.4h, v1.4h\n"
+    "smlal2 v7.4s, v24.8h, v1.8h\n"
+    "tbz x0, #2, 101f\n"
+    "ld1 { v27.s }[0], [x26], #0x4\n"
+    "tbz x0, #1, 100f\n"
+    "ld1 { v27.h }[2], [x26], #0x2\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[6], [x26]\n"
     "b 103f\n"
     "100:"  // Oddments: Load (5, 2): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[4], [x22]\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[4], [x26]\n"
     "b 103f\n"
     "101:"  // Oddments: Load (5, 2): Bit 2: Unset
-    "tbz x4, #1, 102f\n"
-    "ld1 { v27.h }[0], [x22], #0x2\n"
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[2], [x22]\n"
+    "tbz x0, #1, 102f\n"
+    "ld1 { v27.h }[0], [x26], #0x2\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[2], [x26]\n"
     "b 103f\n"
     "102:"  // Oddments: Load (5, 2): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 103f\n"
-    "ld1 { v27.b }[0], [x22]\n"
+    "tbz x0, #0, 103f\n"
+    "ld1 { v27.b }[0], [x26]\n"
     "103:"  // Oddments: Load (5, 2): Bit 2: End
+    "ldr d2, [x23, #0xb0]\n"
     "usubl v27.8h, v27.8b, v9.8b\n"
-    "ldr d2, [x3, #0xb0]\n"
-    "smlal v8.4s, v27.4h, v1.4h\n"
-    "ldr x20, [x25, #0x108]\n"
-    "add x20, x20, x10\n"
-    "smlal2 v5.4s, v27.8h, v1.8h\n"
     "ssubl v2.8h, v2.8b, v14.8b\n"
+    "ldr x12, [x20, #0x108]\n"
+    "smlal v6.4s, v27.4h, v1.4h\n"
+    "smlal2 v5.4s, v27.8h, v1.8h\n"
+    "add x12, x12, x24\n"
     "smlal v15.4s, v31.4h, v2.4h\n"
-    "smlal2 v18.4s, v31.8h, v2.8h\n"
-    "smlal v16.4s, v30.4h, v2.4h\n"
-    "smlal2 v21.4s, v30.8h, v2.8h\n"
-    "smlal v7.4s, v27.4h, v2.4h\n"
-    "smlal2 v17.4s, v27.8h, v2.8h\n"
-    "tbz x4, #2, 105f\n"
-    "ld1 { v25.s }[0], [x20], #0x4\n"
-    "tbz x4, #1, 104f\n"
-    "ld1 { v25.h }[2], [x20], #0x2\n"
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[6], [x20]\n"
+    "smlal2 v16.4s, v31.8h, v2.8h\n"
+    "smlal v17.4s, v30.4h, v2.4h\n"
+    "smlal2 v8.4s, v30.8h, v2.8h\n"
+    "smlal v10.4s, v27.4h, v2.4h\n"
+    "smlal2 v7.4s, v27.8h, v2.8h\n"
+    "tbz x0, #2, 105f\n"
+    "ld1 { v25.s }[0], [x12], #0x4\n"
+    "tbz x0, #1, 104f\n"
+    "ld1 { v25.h }[2], [x12], #0x2\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[6], [x12]\n"
     "b 107f\n"
     "104:"  // Oddments: Load (5, 3): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[4], [x20]\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[4], [x12]\n"
     "b 107f\n"
     "105:"  // Oddments: Load (5, 3): Bit 2: Unset
-    "tbz x4, #1, 106f\n"
-    "ld1 { v25.h }[0], [x20], #0x2\n"
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[2], [x20]\n"
+    "tbz x0, #1, 106f\n"
+    "ld1 { v25.h }[0], [x12], #0x2\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[2], [x12]\n"
     "b 107f\n"
     "106:"  // Oddments: Load (5, 3): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 107f\n"
-    "ld1 { v25.b }[0], [x20]\n"
+    "tbz x0, #0, 107f\n"
+    "ld1 { v25.b }[0], [x12]\n"
     "107:"  // Oddments: Load (5, 3): Bit 2: End
+    "ldr d3, [x23, #0xb8]\n"
     "usubl v25.8h, v25.8b, v9.8b\n"
-    "ldr d3, [x3, #0xb8]\n"
-    "smlal v8.4s, v25.4h, v2.4h\n"
-    "ldr x13, [x25, #0x110]\n"
-    "add x13, x13, x10\n"
-    "smlal2 v5.4s, v25.8h, v2.8h\n"
     "ssubl v3.8h, v3.8b, v14.8b\n"
+    "ldr x14, [x20, #0x110]\n"
+    "smlal v6.4s, v25.4h, v2.4h\n"
+    "smlal2 v5.4s, v25.8h, v2.8h\n"
+    "add x14, x14, x24\n"
     "smlal v15.4s, v30.4h, v3.4h\n"
-    "smlal2 v18.4s, v30.8h, v3.8h\n"
-    "smlal v16.4s, v28.4h, v3.4h\n"
-    "smlal2 v21.4s, v28.8h, v3.8h\n"
-    "smlal v7.4s, v25.4h, v3.4h\n"
-    "smlal2 v17.4s, v25.8h, v3.8h\n"
-    "tbz x4, #2, 109f\n"
-    "ld1 { v24.s }[0], [x13], #0x4\n"
-    "tbz x4, #1, 108f\n"
-    "ld1 { v24.h }[2], [x13], #0x2\n"
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[6], [x13]\n"
+    "smlal2 v16.4s, v30.8h, v3.8h\n"
+    "smlal v17.4s, v28.4h, v3.4h\n"
+    "smlal2 v8.4s, v28.8h, v3.8h\n"
+    "smlal v10.4s, v25.4h, v3.4h\n"
+    "smlal2 v7.4s, v25.8h, v3.8h\n"
+    "tbz x0, #2, 109f\n"
+    "ld1 { v24.s }[0], [x14], #0x4\n"
+    "tbz x0, #1, 108f\n"
+    "ld1 { v24.h }[2], [x14], #0x2\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[6], [x14]\n"
     "b 111f\n"
     "108:"  // Oddments: Load (5, 4): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[4], [x13]\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[4], [x14]\n"
     "b 111f\n"
     "109:"  // Oddments: Load (5, 4): Bit 2: Unset
-    "tbz x4, #1, 110f\n"
-    "ld1 { v24.h }[0], [x13], #0x2\n"
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[2], [x13]\n"
+    "tbz x0, #1, 110f\n"
+    "ld1 { v24.h }[0], [x14], #0x2\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[2], [x14]\n"
     "b 111f\n"
     "110:"  // Oddments: Load (5, 4): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 111f\n"
-    "ld1 { v24.b }[0], [x13]\n"
+    "tbz x0, #0, 111f\n"
+    "ld1 { v24.b }[0], [x14]\n"
     "111:"  // Oddments: Load (5, 4): Bit 2: End
+    "ldr d4, [x23, #0xc0]\n"
     "usubl v24.8h, v24.8b, v9.8b\n"
-    "ldr d4, [x3, #0xc0]\n"
-    "smlal v8.4s, v24.4h, v3.4h\n"
-    "ldr x21, [x25, #0x118]\n"
-    "add x21, x21, x10\n"
-    "smlal2 v5.4s, v24.8h, v3.8h\n"
     "ssubl v4.8h, v4.8b, v14.8b\n"
+    "ldr x21, [x20, #0x118]\n"
+    "smlal v6.4s, v24.4h, v3.4h\n"
+    "smlal2 v5.4s, v24.8h, v3.8h\n"
+    "add x21, x21, x24\n"
     "smlal v15.4s, v28.4h, v4.4h\n"
-    "smlal2 v18.4s, v28.8h, v4.8h\n"
-    "smlal v16.4s, v26.4h, v4.4h\n"
-    "smlal2 v21.4s, v26.8h, v4.8h\n"
-    "smlal v7.4s, v24.4h, v4.4h\n"
-    "smlal2 v17.4s, v24.8h, v4.8h\n"
-    "tbz x4, #2, 113f\n"
+    "smlal2 v16.4s, v28.8h, v4.8h\n"
+    "smlal v17.4s, v26.4h, v4.4h\n"
+    "smlal2 v8.4s, v26.8h, v4.8h\n"
+    "smlal v10.4s, v24.4h, v4.4h\n"
+    "smlal2 v7.4s, v24.8h, v4.8h\n"
+    "tbz x0, #2, 113f\n"
     "ld1 { v27.s }[0], [x21], #0x4\n"
-    "tbz x4, #1, 112f\n"
+    "tbz x0, #1, 112f\n"
     "ld1 { v27.h }[2], [x21], #0x2\n"
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[6], [x21]\n"
     "b 115f\n"
     "112:"  // Oddments: Load (5, 5): Bit 2: Bit 1: Unset
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[4], [x21]\n"
     "b 115f\n"
     "113:"  // Oddments: Load (5, 5): Bit 2: Unset
-    "tbz x4, #1, 114f\n"
+    "tbz x0, #1, 114f\n"
     "ld1 { v27.h }[0], [x21], #0x2\n"
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[2], [x21]\n"
     "b 115f\n"
     "114:"  // Oddments: Load (5, 5): Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 115f\n"
+    "tbz x0, #0, 115f\n"
     "ld1 { v27.b }[0], [x21]\n"
     "115:"  // Oddments: Load (5, 5): Bit 2: End
     "usubl v27.8h, v27.8b, v9.8b\n"
-    "smlal v8.4s, v27.4h, v4.4h\n"
+    "smlal v6.4s, v27.4h, v4.4h\n"
     "smlal2 v5.4s, v27.8h, v4.8h\n"
-    "tbz x4, #2, 117f\n"
-    "ld1 { v6.4s }, [x2], #0x10\n"
-    "ld1 { v19.4s }, [x5], #0x10\n"
-    "tbz x4, #1, 116f\n"
-    "ld1 { v20.d }[0], [x2], #0x8\n"
-    "ld1 { v12.d }[0], [x5], #0x8\n"
-    "tbz x4, #0, 119f\n"
-    "ld1 { v20.s }[2], [x2]\n"
-    "ld1 { v12.s }[2], [x5]\n"
+    "tbz x0, #2, 117f\n"
+    "ld1 { v12.4s }, [x10], #0x10\n"
+    "ld1 { v19.4s }, [x1], #0x10\n"
+    "tbz x0, #1, 116f\n"
+    "ld1 { v20.d }[0], [x10], #0x8\n"
+    "ld1 { v29.d }[0], [x1], #0x8\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v20.s }[2], [x10]\n"
+    "ld1 { v29.s }[2], [x1]\n"
     "b 119f\n"
     "116:"  // Oddments: Load requant params: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 119f\n"
-    "ld1 { v20.s }[0], [x2]\n"
-    "ld1 { v12.s }[0], [x5]\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v20.s }[0], [x10]\n"
+    "ld1 { v29.s }[0], [x1]\n"
     "b 119f\n"
     "117:"  // Oddments: Load requant params: Bit 2: Unset
-    "tbz x4, #1, 118f\n"
-    "ld1 { v6.d }[0], [x2], #0x8\n"
-    "ld1 { v19.d }[0], [x5], #0x8\n"
-    "tbz x4, #0, 119f\n"
-    "ld1 { v6.s }[2], [x2]\n"
-    "ld1 { v19.s }[2], [x5]\n"
+    "tbz x0, #1, 118f\n"
+    "ld1 { v12.d }[0], [x10], #0x8\n"
+    "ld1 { v19.d }[0], [x1], #0x8\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v12.s }[2], [x10]\n"
+    "ld1 { v19.s }[2], [x1]\n"
     "b 119f\n"
     "118:"  // Oddments: Load requant params: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 119f\n"
-    "ld1 { v6.s }[0], [x2]\n"
-    "ld1 { v19.s }[0], [x5]\n"
+    "tbz x0, #0, 119f\n"
+    "ld1 { v12.s }[0], [x10]\n"
+    "ld1 { v19.s }[0], [x1]\n"
     "119:"  // Oddments: Load requant params: Bit 2: End
-    "sqrdmulh v15.4s, v15.4s, v6.4s\n"
-    "add x17, x17, x1\n"
-    "sqrdmulh v18.4s, v18.4s, v20.4s\n"
-    "add x16, x16, x1\n"
-    "sqrdmulh v16.4s, v16.4s, v6.4s\n"
-    "add x6, x6, x1\n"
-    "sqrdmulh v21.4s, v21.4s, v20.4s\n"
-    "add x8, x8, x1\n"
-    "sqrdmulh v7.4s, v7.4s, v6.4s\n"
-    "and v28.16b, v15.16b, v19.16b\n"
-    "and v26.16b, v18.16b, v12.16b\n"
-    "and v29.16b, v16.16b, v19.16b\n"
-    "sshr v28.4s, v28.4s, #0x1f\n"
+    "sqdmulh v15.4s, v15.4s, v12.4s\n"
+    "sqdmulh v17.4s, v17.4s, v12.4s\n"
+    "add x16, x16, x22\n"
+    "add x8, x8, x22\n"
+    "sqdmulh v10.4s, v10.4s, v12.4s\n"
+    "sqdmulh v6.4s, v6.4s, v12.4s\n"
+    "add x4, x4, x22\n"
+    "add x7, x7, x22\n"
+    "and v23.16b, v15.16b, v19.16b\n"
+    "sqdmulh v16.4s, v16.4s, v20.4s\n"
+    "and v22.16b, v17.16b, v19.16b\n"
+    "sqdmulh v8.4s, v8.4s, v20.4s\n"
+    "and v21.16b, v10.16b, v19.16b\n"
+    "sqdmulh v7.4s, v7.4s, v20.4s\n"
+    "and v26.16b, v6.16b, v19.16b\n"
+    "sqdmulh v5.4s, v5.4s, v20.4s\n"
+    "sshr v23.4s, v23.4s, #0x1f\n"
+    "and v4.16b, v16.16b, v29.16b\n"
+    "sshr v22.4s, v22.4s, #0x1f\n"
+    "and v2.16b, v8.16b, v29.16b\n"
+    "sshr v21.4s, v21.4s, #0x1f\n"
+    "and v3.16b, v7.16b, v29.16b\n"
     "sshr v26.4s, v26.4s, #0x1f\n"
-    "sshr v29.4s, v29.4s, #0x1f\n"
-    "sqadd v15.4s, v15.4s, v28.4s\n"
-    "sqadd v18.4s, v18.4s, v26.4s\n"
-    "sqadd v16.4s, v16.4s, v29.4s\n"
-    "and v4.16b, v21.16b, v12.16b\n"
-    "srshl v15.4s, v15.4s, v19.4s\n"
-    "srshl v18.4s, v18.4s, v12.4s\n"
-    "srshl v16.4s, v16.4s, v19.4s\n"
+    "and v25.16b, v5.16b, v29.16b\n"
+    "sqadd v15.4s, v15.4s, v23.4s\n"
     "sshr v4.4s, v4.4s, #0x1f\n"
-    "add v15.4s, v15.4s, v10.4s\n"
-    "add v18.4s, v18.4s, v10.4s\n"
-    "add v16.4s, v16.4s, v10.4s\n"
-    "smin v15.4s, v15.4s, v13.4s\n"
-    "smin v18.4s, v18.4s, v13.4s\n"
-    "smin v16.4s, v16.4s, v13.4s\n"
-    "smax v15.4s, v15.4s, v11.4s\n"
-    "smax v18.4s, v18.4s, v11.4s\n"
-    "smax v16.4s, v16.4s, v11.4s\n"
-    "sqadd v21.4s, v21.4s, v4.4s\n"
-    "uzp1 v15.16b, v15.16b, v18.16b\n"
-    "and v25.16b, v7.16b, v19.16b\n"
-    "uzp1 v15.16b, v15.16b, v15.16b\n"
-    "srshl v21.4s, v21.4s, v12.4s\n"
+    "sqadd v17.4s, v17.4s, v22.4s\n"
+    "sshr v2.4s, v2.4s, #0x1f\n"
+    "sqadd v10.4s, v10.4s, v21.4s\n"
+    "sshr v3.4s, v3.4s, #0x1f\n"
+    "sqadd v6.4s, v6.4s, v26.4s\n"
     "sshr v25.4s, v25.4s, #0x1f\n"
-    "sqrdmulh v17.4s, v17.4s, v20.4s\n"
-    "sqrdmulh v8.4s, v8.4s, v6.4s\n"
-    "add v21.4s, v21.4s, v10.4s\n"
-    "sqadd v7.4s, v7.4s, v25.4s\n"
-    "and v31.16b, v17.16b, v12.16b\n"
-    "smin v21.4s, v21.4s, v13.4s\n"
-    "and v24.16b, v8.16b, v19.16b\n"
-    "srshl v7.4s, v7.4s, v19.4s\n"
-    "smax v21.4s, v21.4s, v11.4s\n"
-    "sshr v31.4s, v31.4s, #0x1f\n"
-    "sshr v24.4s, v24.4s, #0x1f\n"
-    "uzp1 v16.16b, v16.16b, v21.16b\n"
-    "add v7.4s, v7.4s, v10.4s\n"
-    "uzp1 v16.16b, v16.16b, v16.16b\n"
-    "sqadd v17.4s, v17.4s, v31.4s\n"
-    "smin v7.4s, v7.4s, v13.4s\n"
-    "sqadd v8.4s, v8.4s, v24.4s\n"
-    "sqrdmulh v5.4s, v5.4s, v20.4s\n"
-    "smax v7.4s, v7.4s, v11.4s\n"
-    "srshl v17.4s, v17.4s, v12.4s\n"
-    "srshl v8.4s, v8.4s, v19.4s\n"
-    "and v1.16b, v5.16b, v12.16b\n"
-    "add v17.4s, v17.4s, v10.4s\n"
-    "add v8.4s, v8.4s, v10.4s\n"
-    "sshr v1.4s, v1.4s, #0x1f\n"
-    "smin v17.4s, v17.4s, v13.4s\n"
-    "smin v8.4s, v8.4s, v13.4s\n"
-    "sqadd v5.4s, v5.4s, v1.4s\n"
-    "smax v17.4s, v17.4s, v11.4s\n"
-    "smax v8.4s, v8.4s, v11.4s\n"
-    "srshl v5.4s, v5.4s, v12.4s\n"
-    "uzp1 v7.16b, v7.16b, v17.16b\n"
-    "uzp1 v7.16b, v7.16b, v7.16b\n"
-    "add v5.4s, v5.4s, v10.4s\n"
-    "smin v5.4s, v5.4s, v13.4s\n"
-    "smax v5.4s, v5.4s, v11.4s\n"
-    "uzp1 v8.16b, v8.16b, v5.16b\n"
-    "uzp1 v8.16b, v8.16b, v8.16b\n"
-    "tbz x4, #2, 121f\n"
-    "st1 { v15.s }[0], [x17], #0x4\n"
-    "st1 { v16.s }[0], [x16], #0x4\n"
-    "st1 { v7.s }[0], [x6], #0x4\n"
-    "st1 { v8.s }[0], [x8], #0x4\n"
-    "tbz x4, #1, 120f\n"
-    "st1 { v15.h }[2], [x17], #0x2\n"
-    "st1 { v16.h }[2], [x16], #0x2\n"
-    "st1 { v7.h }[2], [x6], #0x2\n"
-    "st1 { v8.h }[2], [x8], #0x2\n"
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[6], [x17], #0x1\n"
-    "st1 { v16.b }[6], [x16], #0x1\n"
-    "st1 { v7.b }[6], [x6], #0x1\n"
-    "st1 { v8.b }[6], [x8], #0x1\n"
+    "srshl v15.4s, v15.4s, v19.4s\n"
+    "sqadd v16.4s, v16.4s, v4.4s\n"
+    "srshl v17.4s, v17.4s, v19.4s\n"
+    "sqadd v8.4s, v8.4s, v2.4s\n"
+    "srshl v10.4s, v10.4s, v19.4s\n"
+    "sqadd v7.4s, v7.4s, v3.4s\n"
+    "srshl v6.4s, v6.4s, v19.4s\n"
+    "sqadd v5.4s, v5.4s, v25.4s\n"
+    "srshl v16.4s, v16.4s, v29.4s\n"
+    "sqxtn v15.4h, v15.4s\n"
+    "srshl v8.4s, v8.4s, v29.4s\n"
+    "sqxtn v17.4h, v17.4s\n"
+    "srshl v7.4s, v7.4s, v29.4s\n"
+    "sqxtn v10.4h, v10.4s\n"
+    "srshl v5.4s, v5.4s, v29.4s\n"
+    "sqxtn v6.4h, v6.4s\n"
+    "sqxtn2 v15.8h, v16.4s\n"
+    "sqxtn2 v17.8h, v8.4s\n"
+    "sqxtn2 v10.8h, v7.4s\n"
+    "sqxtn2 v6.8h, v5.4s\n"
+    "sqadd v15.8h, v15.8h, v18.8h\n"
+    "sqadd v17.8h, v17.8h, v18.8h\n"
+    "sqadd v10.8h, v10.8h, v18.8h\n"
+    "sqadd v6.8h, v6.8h, v18.8h\n"
+    "smax v15.8h, v15.8h, v11.8h\n"
+    "smax v17.8h, v17.8h, v11.8h\n"
+    "smax v10.8h, v10.8h, v11.8h\n"
+    "smax v6.8h, v6.8h, v11.8h\n"
+    "smin v15.8h, v15.8h, v13.8h\n"
+    "smin v17.8h, v17.8h, v13.8h\n"
+    "smin v10.8h, v10.8h, v13.8h\n"
+    "smin v6.8h, v6.8h, v13.8h\n"
+    "uzp1 v15.16b, v15.16b, v15.16b\n"
+    "uzp1 v17.16b, v17.16b, v17.16b\n"
+    "uzp1 v10.16b, v10.16b, v10.16b\n"
+    "uzp1 v6.16b, v6.16b, v6.16b\n"
+    "tbz x0, #2, 121f\n"
+    "st1 { v15.s }[0], [x16], #0x4\n"
+    "st1 { v17.s }[0], [x8], #0x4\n"
+    "st1 { v10.s }[0], [x4], #0x4\n"
+    "st1 { v6.s }[0], [x7], #0x4\n"
+    "tbz x0, #1, 120f\n"
+    "st1 { v15.h }[2], [x16], #0x2\n"
+    "st1 { v17.h }[2], [x8], #0x2\n"
+    "st1 { v10.h }[2], [x4], #0x2\n"
+    "st1 { v6.h }[2], [x7], #0x2\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[6], [x16], #0x1\n"
+    "st1 { v17.b }[6], [x8], #0x1\n"
+    "st1 { v10.b }[6], [x4], #0x1\n"
+    "st1 { v6.b }[6], [x7], #0x1\n"
     "b 123f\n"
     "120:"  // Oddments: Bit 2: Bit 1: Unset
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[4], [x17], #0x1\n"
-    "st1 { v16.b }[4], [x16], #0x1\n"
-    "st1 { v7.b }[4], [x6], #0x1\n"
-    "st1 { v8.b }[4], [x8], #0x1\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[4], [x16], #0x1\n"
+    "st1 { v17.b }[4], [x8], #0x1\n"
+    "st1 { v10.b }[4], [x4], #0x1\n"
+    "st1 { v6.b }[4], [x7], #0x1\n"
     "b 123f\n"
     "121:"  // Oddments: Bit 2: Unset
-    "tbz x4, #1, 122f\n"
-    "st1 { v15.h }[0], [x17], #0x2\n"
-    "st1 { v16.h }[0], [x16], #0x2\n"
-    "st1 { v7.h }[0], [x6], #0x2\n"
-    "st1 { v8.h }[0], [x8], #0x2\n"
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[2], [x17], #0x1\n"
-    "st1 { v16.b }[2], [x16], #0x1\n"
-    "st1 { v7.b }[2], [x6], #0x1\n"
-    "st1 { v8.b }[2], [x8], #0x1\n"
+    "tbz x0, #1, 122f\n"
+    "st1 { v15.h }[0], [x16], #0x2\n"
+    "st1 { v17.h }[0], [x8], #0x2\n"
+    "st1 { v10.h }[0], [x4], #0x2\n"
+    "st1 { v6.h }[0], [x7], #0x2\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[2], [x16], #0x1\n"
+    "st1 { v17.b }[2], [x8], #0x1\n"
+    "st1 { v10.b }[2], [x4], #0x1\n"
+    "st1 { v6.b }[2], [x7], #0x1\n"
     "b 123f\n"
     "122:"  // Oddments: Bit 2: Unset: Bit 1: Unset
-    "tbz x4, #0, 123f\n"
-    "st1 { v15.b }[0], [x17], #0x1\n"
-    "st1 { v16.b }[0], [x16], #0x1\n"
-    "st1 { v7.b }[0], [x6], #0x1\n"
-    "st1 { v8.b }[0], [x8], #0x1\n"
+    "tbz x0, #0, 123f\n"
+    "st1 { v15.b }[0], [x16], #0x1\n"
+    "st1 { v17.b }[0], [x8], #0x1\n"
+    "st1 { v10.b }[0], [x4], #0x1\n"
+    "st1 { v6.b }[0], [x7], #0x1\n"
     "123:"  // Oddments: Bit 2: End
-
     "124:"  // End
-
     :
     : [offsetof_Params_bias] "I" (offsetof(Params, bias)), [offsetof_Params_inptrs] "I" (offsetof(Params, inptrs)), [offsetof_Params_n_channels] "I" (offsetof(Params, n_channels)), [offsetof_Params_outptrs] "I" (offsetof(Params, outptrs)), [offsetof_Params_requant] "I" (offsetof(Params, requant)), [offsetof_Params_requant_muls] "I" (offsetof(Params, requant_muls)), [offsetof_Params_requant_shifts] "I" (offsetof(Params, requant_shifts)), [offsetof_Params_weights] "I" (offsetof(Params, weights)), [offsetof_Requantize32_a_offset] "I" (offsetof(arm_gemm::Requantize32, a_offset)), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [params] "r" (&params)
     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst.hpp
index 2bfeac0..6bdcca1 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,28 +28,23 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const void *, const arm_gemm::Requantize32&, const unsigned int, const unsigned int);
 
-struct a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst
+class a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst : public GenericDepthfirstKernelStrategy<uint8_t, int8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
+  KernelType kernel = a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst_impl;
 
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const void *, const arm_gemm::Requantize32&, const unsigned int, const unsigned int);
+  public:
+  a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) : GenericDepthfirstKernelStrategy<uint8_t, int8_t, uint8_t, int32_t>(9, arm_gemm::VLType::None) {}
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int n_output_points = 9;
-
-  kern_type kernel = a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst_impl;
-
-  a64_u8s8u8q_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) {}
+  KernelType get_kernel() const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+#endif // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
index 8020305..394df36 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,31 +28,25 @@
 
 #pragma once
 
+#if defined(__aarch64__)
+
 namespace arm_conv {
 namespace depthwise {
 
 void a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const int8_t *, const int32_t *, const unsigned int, const unsigned int, const int32_t *, const int32_t *, const int32_t *, const arm_gemm::Requantize32&);
 
-struct a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst
+struct a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst : GenericDepthfirstMultiplierKernelStrategy<uint8_t, int8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
-
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const int8_t *, const int32_t *, const unsigned int, const unsigned int, const int32_t *, const int32_t *, const int32_t *, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
-
-  constexpr static unsigned int output_rows(void) { return 2; };
-  constexpr static unsigned int output_cols(void) { return 8; };
-
-  constexpr static unsigned int output_col_regs(void) { return 2; };
-
-  kern_type kernel = a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
-
-  a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *) {}
+  using Parent = GenericDepthfirstMultiplierKernelStrategy<uint8_t, int8_t, uint8_t, int32_t>;
+  a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *)
+  : Parent(2, 8, arm_gemm::VLType::None)
+  {
+  }
+  Parent::KernelType kernel = a64_u8s8u8q_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
+
+#endif  // defined(__aarch64__)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index 1cfea9d..1c1fb25 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
-
-  sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  sve_fp16_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>(2, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
index af8af18..d49b14e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) &&  defined(__ARM_FP16_ARGS)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 3;
   constexpr static unsigned int output_cols = 3;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *) {}
+  sve_fp16_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>(3, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
index 60234c8..ac6ae28 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 4;
   constexpr static unsigned int output_cols = 4;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *) {}
+  sve_fp16_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>(4, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index 5968309..82173ee 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  sve_fp16_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>(2, 3, 2) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index 4a9bd33..f5d4189 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
 void sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
 
-class sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>
 {
   private:
-  typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16);
-  indirect_kern_type m_indirect_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16);
-  direct_kern_type m_direct_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef __fp16 return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = __fp16;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<__fp16, __fp16, __fp16, __fp16>(2, 5, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const __fp16 *const *>(input_ptrs),
-      reinterpret_cast<__fp16 *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col,
-      static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const __fp16 *>(activation_min),
-      *static_cast<const __fp16 *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index e07e631..d7b1de2 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
-
-  sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  sve_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<float, float, float, float>(2, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
index eb9de9f..41ad193 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 3;
   constexpr static unsigned int output_cols = 3;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *) {}
+  sve_fp32_nhwc_3x3_s1_output3x3_mla_depthfirst(const CPUInfo *)
+  : DepthwiseDepthfirstStrategy<float, float, float, float>(3, 3, 1) {}
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
index d7be9b1..6073b2b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 4;
   constexpr static unsigned int output_cols = 4;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *) {}
+  sve_fp32_nhwc_3x3_s1_output4x4_mla_depthfirst(const CPUInfo *)  // The CPUInfo argument is unnamed and unused.
+  : Parent(4, 3, 1) {}  // Presumably (output size, kernel size, stride) -- TODO confirm against the Parent constructor.
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }  // Returns m_indirect_kernel (set to the *_indirect_impl function).
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }  // Returns m_direct_kernel (set to the *_direct_impl function).
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index 28d44d0..17ac74e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
-
-  sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  sve_fp32_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *)  // The CPUInfo argument is unnamed and unused.
+  : Parent(2, 3, 2) {}  // Presumably (output size, kernel size, stride) -- TODO confirm against the Parent constructor.
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }  // Returns m_indirect_kernel (set to the *_indirect_impl function).
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }  // Returns m_direct_kernel (set to the *_direct_impl function).
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index 751874f..2449c96 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,7 @@
 
 #pragma once
 
-#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -36,19 +36,16 @@
 void sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
 void sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
 
-class sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
+class sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<float, float, float, float>
 {
   private:
-  typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
-  indirect_kern_type m_indirect_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
-
-  typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-  direct_kern_type m_direct_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
+  using Parent = DepthwiseDepthfirstStrategy<float, float, float, float>;
+  Parent::IndirectKernelType m_indirect_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
+  Parent::DirectKernelType m_direct_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
 
   public:
-  typedef float return_type;
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
+  using return_type = float;
+  constexpr static auto vl_type = arm_gemm::VLType::SVE;
 
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
@@ -59,63 +56,16 @@
   constexpr static unsigned int output_rows = 2;
   constexpr static unsigned int output_cols = 2;
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
-
-  sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *)  // The CPUInfo argument is unnamed and unused.
+  : Parent(2, 5, 1) {}  // Presumably (output size, kernel size, stride) -- TODO confirm against the Parent constructor.
 
   arm_gemm::VLType get_vl_type(void) const override { return vl_type; }
 
-  unsigned int get_kernel_rows(void) const override { return kernel_rows; }
-  unsigned int get_kernel_cols(void) const override { return kernel_cols; }
-
-  unsigned int get_stride_rows(void) const override { return stride_rows; }
-  unsigned int get_stride_cols(void) const override { return stride_cols; }
-
-  unsigned int get_output_rows(void) const override { return output_rows; }
-  unsigned int get_output_cols(void) const override { return output_cols; }
-
-  unsigned int get_input_rows(void) const override { return input_rows; }
-  unsigned int get_input_cols(void) const override { return input_cols; }
-
-  void indirect_kernel(
-    const void *const *const input_ptrs,
-    void *const *const outptrs,
-    const void *params,
-    unsigned int n_channels,
-    const void *activation_min,
-    const void *activation_max
-  ) const override
-  {
-    m_indirect_kernel(
-      reinterpret_cast<const float *const *>(input_ptrs),
-      reinterpret_cast<float *const *>(outptrs),
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
-
-  void direct_kernel(
-    const unsigned int n_tile_rows, const unsigned int n_tile_cols,
-    const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
-    void *outptr, int64_t ld_output_row, int64_t ld_output_col,
-    const void *params, unsigned int n_channels,
-    const void *activation_min, const void *activation_max
-  ) const override
-  {
-    m_direct_kernel(
-      n_tile_rows, n_tile_cols,
-      static_cast<const float *>(inptr), ld_input_row, ld_input_col,
-      static_cast<float *>(outptr), ld_output_row, ld_output_col,
-      params, n_channels,
-      *static_cast<const float *>(activation_min),
-      *static_cast<const float *>(activation_max)
-    );
-  }
+  Parent::IndirectKernelType get_indirect_kernel() const override { return m_indirect_kernel; }  // Returns m_indirect_kernel (set to the *_indirect_impl function).
+  Parent::DirectKernelType get_direct_kernel() const override { return m_direct_kernel; }  // Returns m_direct_kernel (set to the *_direct_impl function).
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_generic_output9_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_generic_output9_mla_depthfirst.hpp
index bd071d3..62faca9 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_generic_output9_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_generic_output9_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,22 +35,14 @@
 
 void sve_fp32_nhwc_generic_output9_mla_depthfirst_impl(const float *const *const, float *const *const, const void *, const void *, const unsigned int, const unsigned int, const float, const float);
 
-struct sve_fp32_nhwc_generic_output9_mla_depthfirst
+class sve_fp32_nhwc_generic_output9_mla_depthfirst : public GenericDepthfirstKernelStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
+  KernelType kernel = sve_fp32_nhwc_generic_output9_mla_depthfirst_impl;  // Private (class default access); read through get_kernel().
 
-  typedef void (*kern_type)(const float *const *const, float *const *const, const void *, const void *, const unsigned int, const unsigned int, const float, const float);
+  public:  // The constructor below ignores its CPUInfo argument and forwards 9 and VLType::SVE to the base class.
+  sve_fp32_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) : GenericDepthfirstKernelStrategy<float, float, float, float>(9, arm_gemm::VLType::SVE) {}
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  constexpr static unsigned int n_output_points = 9;
-
-  kern_type kernel = sve_fp32_nhwc_generic_output9_mla_depthfirst_impl;
-
-  sve_fp32_nhwc_generic_output9_mla_depthfirst(const CPUInfo *) {}
+  KernelType get_kernel() const override { return kernel; }  // Returns the kernel member initialised above.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp
index 563f0fc..8640343 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,33 +35,24 @@
 
 void sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst_impl(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
 
-struct sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst
+struct sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst : DepthfirstMultiplierStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
-
-  typedef void (*kern_type)(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
+  using Parent = DepthfirstMultiplierStrategy<float, float, float, float>;  // Shorthand for the base class.
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 3;
-  constexpr static unsigned int output_cols = 3;
+  sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unused.
+  : Parent(3, 3, kernel_rows, kernel_cols, stride_rows, stride_cols)  // Leading 3, 3: presumably output rows/cols -- TODO confirm.
+  {
+  }
 
-  constexpr static unsigned int input_rows = 7;
-  constexpr static unsigned int input_cols = 7;
-  constexpr static unsigned int input_col_quads = 2;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::SVE; }  // Always reports the SVE vector-length type.
 
-  kern_type kernel = sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst_impl;
-
-  sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = sve_fp32_packed_to_nhwc_3x3_s2_with_multiplier_output3x3_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the kernel member declared on the previous line.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp
index e9378c2..a4ee87c 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,33 +35,24 @@
 
 void sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst_impl(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
 
-struct sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst
+struct sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst : DepthfirstMultiplierStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
-
-  typedef void (*kern_type)(const float *const *const, float *const *const, const void *, const unsigned int, const float, const float);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
+  using Parent = DepthfirstMultiplierStrategy<float, float, float, float>;  // Shorthand for the base class.
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 4;
+  sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unused.
+  : Parent(2, 4, kernel_rows, kernel_cols, stride_rows, stride_cols)  // Leading 2, 4: presumably output rows/cols -- TODO confirm.
+  {
+  }
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 8;
-  constexpr static unsigned int input_col_quads = 2;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::SVE; }  // Always reports the SVE vector-length type.
 
-  kern_type kernel = sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst_impl;
-
-  sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = sve_fp32_packed_to_nhwc_5x5_s1_with_multiplier_output2x4_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the kernel member declared on the previous line.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
index 6849e56..e1f0b50 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,35 +28,25 @@
 
 #pragma once
 
-#if defined(ARM_COMPUTE_ENABLE_SVE)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl(const float *const *const, float *const *const, const float *, const float *, const unsigned int, const unsigned int, const float, const float);
 
-struct sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst
+struct sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst : GenericDepthfirstMultiplierKernelStrategy<float, float, float, float>
 {
-  typedef float bias_type;
-  typedef float input_type;
-  typedef float weight_type;
-  typedef float return_type;
-
-  typedef void (*kern_type)(const float *const *const, float *const *const, const float *, const float *, const unsigned int, const unsigned int, const float, const float);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  constexpr static unsigned int output_rows(void) { return 2; };
-  constexpr static unsigned int output_cols(void) { return 8; };
-
-  constexpr static unsigned int output_col_regs(void) { return 2; };
-
-  kern_type kernel = sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
-
-  sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *) {}
+  using Parent = GenericDepthfirstMultiplierKernelStrategy<float, float, float, float>;  // Shorthand for the base class.
+  sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst(const CPUInfo *)  // CPUInfo argument is unused.
+  : Parent(2, 8, arm_gemm::VLType::SVE)  // 2, 8: presumably output rows/cols -- TODO confirm against Parent.
+  {
+  }
+  Parent::KernelType kernel = sve_fp32_packed_to_nhwc_generic_with_multiplier_output2x8_mla_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the kernel member declared on the previous line.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(ARM_COMPUTE_ENABLE_SVE)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
index 39974fd..4e2ee43 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,39 +34,40 @@
 namespace arm_conv {
 namespace depthwise {
 
-void sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
+void sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32&, const int32_t *, const int32_t *, int8_t *const *);
 
-struct sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst
+class sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;  // Private shorthand for the base class.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}  // CPUInfo argument is unused.
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }  // Always reports the SVE vector-length type.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_3x3_dot::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_3x3_dot::get_packed_size;
+  Parent::KernelType kernel = sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the kernel member declared on the previous line.
+  size_t get_storage_size(const DepthwiseArgs &args) const override  // Size query, delegated to the interleave helper.
+  {
+    return interleave_sve_s8q_3x3_dot::get_packed_size(args);
+  }
 
-  kern_type kernel = sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
-
-  sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) {}
+  void pack_parameters(  // Forwards the type-erased biases and weights to the interleave helper as typed pointers.
+    const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleave_sve_s8q_3x3_dot::pack_parameters(  // Note: the helper is passed biases before weights and qp.
+      args.input_channels, buffer, static_cast<const int32_t *>(biases),  // static_cast suffices for void* to object pointer.
+      static_cast<const int8_t *>(weights), qp, ld_weight_col, ld_weight_row
+    );
+  }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
index 8e9e5f4..8008037 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,15 @@
 namespace arm_conv {
 namespace depthwise {
 
-void sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *const inptrs, int8_t *const *const outptrs, const void *params, const uint64_t n_channels, const arm_gemm::Requantize32& qp)
+void sve_s8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(
+  const unsigned int n_channels,
+  const int8_t *const *const inptrs,
+  const int8_t *params,
+  const int32_t *,  // Bias, should be wrapped into the parameters
+  const arm_gemm::Requantize32& qp,
+  const int32_t *, const int32_t *,  // Requant parameters, also wrapped
+  int8_t *const *const outptrs
+)
 {
   __asm__ __volatile__(
     "ldp x11, x10, [%x[inptrs], #0x0]\n"
@@ -446,7 +454,7 @@
     "b.any 1b\n"
     "addvl SP, SP, #8\n"
     : [params] "+&r" (params)
-    : [inptrs] "r" (inptrs), [n_channels] "r" (n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
+    : [inptrs] "r" (inptrs), [n_channels] "r" ((long unsigned int) n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
     : "cc", "memory", "p0", "p1", "p2", "x9", "x10", "x11", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
   );
 }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index f788829..3e97651 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
 
-struct sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst
+class sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>  // Strategy class that publishes sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl through the overridden base-class accessors.
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth in units of vector length ("_vl") - confirm against the base class.
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}  // CPUInfo argument is ignored. NOTE(review): arguments are presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols) - confirm against Parent's constructor.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_3x3_mla::get_packed_size;
+  Parent::KernelType kernel = sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this class.
 
-  kern_type kernel = sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
-
-  sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
index 8738796..3583308 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -415,4 +415,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index 5c2b4f6..78bcd14 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
 
-struct sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst
+class sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>  // Strategy class that publishes sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl through the overridden base-class accessors.
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth in units of vector length ("_vl") - confirm against the base class.
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
+  sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 2, 2) {}  // CPUInfo argument is ignored. NOTE(review): arguments are presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols) - confirm against Parent's constructor.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_3x3_mla::get_packed_size;
+  Parent::KernelType kernel = sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this class.
 
-  kern_type kernel = sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
-
-  sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
index b4a1026..ba8c1fd 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -456,4 +456,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index 948c5ad..41ecd52 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
 
-struct sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst
+class sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>  // Strategy class that publishes sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl through the overridden base-class accessors.
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, int8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth in units of vector length ("_vl") - confirm against the base class.
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
+  sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 5, 5, 1, 1) {}  // CPUInfo argument is ignored. NOTE(review): arguments are presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols) - confirm against Parent's constructor.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_5x5_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_5x5_mla::get_packed_size;
+  Parent::KernelType kernel = sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this class.
 
-  kern_type kernel = sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
-
-  sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
index 565c145..4733c89 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -657,4 +657,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
index 176c4f8..2e8c201 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,33 +35,24 @@
 
 void sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst
+struct sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst : DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>  // Strategy struct that publishes sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl through the overridden base-class accessors.
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
-
-  typedef void (*kern_type)(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
+  using Parent = DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>;  // Shorthand for the base strategy type.
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 4;
+  sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *)  // CPUInfo argument is ignored.
+  : Parent(2, 4, kernel_rows, kernel_cols, stride_rows, stride_cols)  // NOTE(review): 2 and 4 are presumably the output tile rows and columns (the name says output2x4) - confirm against Parent's constructor.
+  {
+  }
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 9;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
 
-  kern_type kernel = sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;
-
-  sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = sve_s8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this struct.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
index 10eee34..4874fb9 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,33 +35,24 @@
 
 void sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst
+struct sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst : DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>  // Strategy struct that publishes sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl through the overridden base-class accessors.
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
-
-  typedef void (*kern_type)(const int8_t *const *const, int8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
+  using Parent = DepthfirstMultiplierStrategy<int8_t, int8_t, int8_t, int32_t>;  // Shorthand for the base strategy type.
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 4;
-  constexpr static unsigned int output_cols = 2;
+  sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *)  // CPUInfo argument is ignored.
+  : Parent(4, 2, kernel_rows, kernel_cols, stride_rows, stride_cols)  // NOTE(review): 4 and 2 are presumably the output tile rows and columns (the name says output4x2) - confirm against Parent's constructor.
+  {
+  }
 
-  constexpr static unsigned int input_rows = 8;
-  constexpr static unsigned int input_cols = 6;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
 
-  kern_type kernel = sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;
-
-  sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = sve_s8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this struct.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
index b5c6e98..0d185fc 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,39 +34,40 @@
 namespace arm_conv {
 namespace depthwise {
 
-void sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
+void sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(unsigned int, const int8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32&, const int32_t *, const int32_t *, int8_t *const *);
 
-struct sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst
+class sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst : public DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>  // Strategy class that publishes sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl and routes sizing/packing to interleave_sve_s8q_3x3_dot.
 {
-  typedef int32_t bias_type;
-  typedef int8_t input_type;
-  typedef int8_t weight_type;
-  typedef int8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<int8_t, int8_t, int8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(const int8_t *const *, int8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int32_t *, const int8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}  // CPUInfo argument is ignored. NOTE(review): arguments are presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols) - confirm against Parent's constructor.
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_3x3_dot::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_3x3_dot::get_packed_size;
+  Parent::KernelType kernel = sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this class.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
+  size_t get_storage_size(const DepthwiseArgs &args) const override  // Delegates to interleave_sve_s8q_3x3_dot::get_packed_size.
+  {
+    return interleave_sve_s8q_3x3_dot::get_packed_size(args);
+  }
 
-  kern_type kernel = sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
-
-  sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) {}
+  void pack_parameters(  // Delegates to interleave_sve_s8q_3x3_dot::pack_parameters.
+    const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleave_sve_s8q_3x3_dot::pack_parameters(
+      args.input_channels, buffer, reinterpret_cast<const int32_t *>(biases),  // Only input_channels is taken from args; the untyped biases pointer is reinterpreted as const int32_t *.
+      reinterpret_cast<const int8_t *>(weights), qp, ld_weight_col, ld_weight_row  // The untyped weights pointer is reinterpreted as const int8_t *; qp is passed ahead of ld_weight_col and ld_weight_row.
+    );
+  }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
index 095c1de..391e98b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,15 @@
 namespace arm_conv {
 namespace depthwise {
 
-void sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const int8_t *const *const inptrs, int8_t *const *const outptrs, const void *params, const uint64_t n_channels, const arm_gemm::Requantize32& qp)
+void sve_s8qs_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(
+  const unsigned int n_channels,
+  const int8_t *const *const inptrs,
+  const int8_t *params,
+  const int32_t *,  // Bias, should be wrapped into the parameters
+  const arm_gemm::Requantize32& qp,
+  const int32_t *, const int32_t *,  // Requant parameters, also wrapped
+  int8_t *const *const outptrs
+)
 {
   __asm__ __volatile__(
     "ldp x11, x10, [%x[inptrs], #0x0]\n"
@@ -377,7 +385,7 @@
     "b.any 1b\n"
     "addvl SP, SP, #8\n"
     : [params] "+&r" (params)
-    : [inptrs] "r" (inptrs), [n_channels] "r" (n_channels), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
+    : [inptrs] "r" (inptrs), [n_channels] "r" ((long unsigned int) n_channels), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
     : "cc", "memory", "p0", "p1", "p2", "x9", "x10", "x11", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
   );
 }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
index a087e80..648b2da 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,39 +34,40 @@
 namespace arm_conv {
 namespace depthwise {
 
-void sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const uint8_t *const *, uint8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
+void sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32&, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst
+class sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy class that publishes sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl and routes sizing/packing to interleave_sve_u8q_3x3_dot.
 {
-  typedef uint32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(const uint8_t *const *, uint8_t *const *, const void *, uint64_t, const arm_gemm::Requantize32&);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int32_t *, const uint8_t *, const arm_gemm::Requantize32 &, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}  // CPUInfo argument is ignored. NOTE(review): arguments are presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols) - confirm against Parent's constructor.
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_u8q_3x3_dot::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_u8q_3x3_dot::get_packed_size;
+  Parent::KernelType kernel = sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this class.
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
+  size_t get_storage_size(const DepthwiseArgs &args) const override  // Delegates to interleave_sve_u8q_3x3_dot::get_packed_size.
+  {
+    return interleave_sve_u8q_3x3_dot::get_packed_size(args);
+  }
 
-  kern_type kernel = sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl;
-
-  sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst(const CPUInfo *) {}
+  void pack_parameters(  // Delegates to interleave_sve_u8q_3x3_dot::pack_parameters.
+    const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp,
+    const void *weights, size_t ld_weight_col, size_t ld_weight_row
+  ) const override
+  {
+    interleave_sve_u8q_3x3_dot::pack_parameters(
+      args.input_channels, buffer, reinterpret_cast<const int32_t *>(biases),  // Only input_channels is taken from args; the untyped biases pointer is reinterpreted as const int32_t *.
+      reinterpret_cast<const uint8_t *>(weights), qp, ld_weight_col, ld_weight_row  // The untyped weights pointer is reinterpreted as const uint8_t *; qp is passed ahead of ld_weight_col and ld_weight_row.
+    );
+  }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
index 0d4b9e6..440f57e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,15 @@
 namespace arm_conv {
 namespace depthwise {
 
-void sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(const uint8_t *const *const inptrs, uint8_t *const *const outptrs, const void *params, const uint64_t n_channels, const arm_gemm::Requantize32& qp)
+void sve_u8q_nhwc_3x3_s1_output2x2_dot_depthfirst_impl(
+  const unsigned int n_channels,
+  const uint8_t *const *const inptrs,
+  const uint8_t *params,
+  const int32_t *,  // Bias, should be wrapped into the parameters
+  const arm_gemm::Requantize32& qp,
+  const int32_t *, const int32_t *,  // Requant parameters, also wrapped
+  uint8_t *const *const outptrs
+)
 {
   __asm__ __volatile__(
     "ldp x11, x10, [%x[inptrs], #0x0]\n"
@@ -446,7 +454,7 @@
     "b.any 1b\n"
     "addvl SP, SP, #8\n"
     : [params] "+&r" (params)
-    : [inptrs] "r" (inptrs), [n_channels] "r" (n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
+    : [inptrs] "r" (inptrs), [n_channels] "r" ((long unsigned int) n_channels), [offsetof_Requantize32_b_offset] "I" (offsetof(arm_gemm::Requantize32, b_offset)), [offsetof_Requantize32_c_offset] "I" (offsetof(arm_gemm::Requantize32, c_offset)), [offsetof_Requantize32_maxval] "I" (offsetof(arm_gemm::Requantize32, maxval)), [offsetof_Requantize32_minval] "I" (offsetof(arm_gemm::Requantize32, minval)), [outptrs] "r" (outptrs), [qp] "r" (&qp)
     : "cc", "memory", "p0", "p1", "p2", "x9", "x10", "x11", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
   );
 }
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index b524fd7..1cf20ef 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,37 +36,25 @@
 
 void sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst
+class sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>  // Strategy class that publishes sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl through the overridden base-class accessors.
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;  // Shorthand for the base strategy type.
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const uint8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }  // This kernel reports the SVE vector-length type.
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }  // NOTE(review): presumably the accumulator depth in units of vector length ("_vl") - confirm against the base class.
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}  // CPUInfo argument is ignored. NOTE(review): arguments are presumably (output_rows, output_cols, kernel_rows, kernel_cols, stride_rows, stride_cols) - confirm against Parent's constructor.
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_u8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_u8q_3x3_mla::get_packed_size;
+  Parent::KernelType kernel = sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;  // Kernel function pointer; defaults to the _impl function declared before this class.
 
-  kern_type kernel = sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
-
-  sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }  // Returns the current value of 'kernel'.
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
index 52dc468..7bfa5fc 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -415,4 +415,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index 9818642..a794095 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst
+class sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const uint8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
+  sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 2, 2) {}
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_u8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_u8q_3x3_mla::get_packed_size;
+  Parent::KernelType kernel = sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
 
-  kern_type kernel = sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
-
-  sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
index 34ba8ec..e1b2d25 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -456,4 +456,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index b1b16c5..ac0a00b 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst
+class sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, uint8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const uint8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const uint8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
+  sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 5, 5, 1, 1) {}
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_u8q_5x5_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_u8q_5x5_mla::get_packed_size;
+  Parent::KernelType kernel = sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
 
-  kern_type kernel = sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
-
-  sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
index 441da6d..0b2182f 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -657,4 +657,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
index dbf70c3..81c954a 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,33 +35,24 @@
 
 void sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst
+struct sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst : DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>
 {
-  typedef uint32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
-
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
+  using Parent = DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>;
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 4;
+  sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *)
+  : Parent(2, 4, kernel_rows, kernel_cols, stride_rows, stride_cols)
+  {
+  }
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 9;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::SVE; }
 
-  kern_type kernel = sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;
-
-  sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = sve_u8q_packed_to_nhwc_3x3_s2_with_multiplier_output2x4_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
index 90fefdc..e7173de 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,33 +35,24 @@
 
 void sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
 
-struct sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst
+struct sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst : DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>
 {
-  typedef uint32_t bias_type;
-  typedef uint8_t input_type;
-  typedef uint8_t weight_type;
-  typedef uint8_t return_type;
-
-  typedef void (*kern_type)(const uint8_t *const *const, uint8_t *const *const, const void *, unsigned int, const arm_gemm::Requantize32&);
-
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
+  using Parent = DepthfirstMultiplierStrategy<uint8_t, uint8_t, uint8_t, int32_t>;
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 4;
-  constexpr static unsigned int output_cols = 2;
+  sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *)
+  : Parent(4, 2, kernel_rows, kernel_cols, stride_rows, stride_cols)
+  {
+  }
 
-  constexpr static unsigned int input_rows = 8;
-  constexpr static unsigned int input_cols = 6;
-  constexpr static unsigned int input_col_quads = 1;
+  arm_gemm::VLType get_vl_type() const override { return arm_gemm::VLType::SVE; }
 
-  kern_type kernel = sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;
-
-  sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst(const CPUInfo *) {}
+  Parent::KernelType kernel = sve_u8q_packed_to_nhwc_5x5_s1_with_multiplier_output4x2_dot_depthfirst_impl;
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
index 8ab2e5b..3d475da 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst
+class sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 
-  constexpr static unsigned int input_rows = 4;
-  constexpr static unsigned int input_cols = 4;
+  sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 1, 1) {}
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_3x3_mla::get_packed_size;
+  Parent::KernelType kernel = sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
 
-  kern_type kernel = sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst_impl;
-
-  sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
index 4b9be8f..dc8fad9 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -415,4 +415,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
index f652e48..9a3db20 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst
+class sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 3;
   constexpr static unsigned int kernel_cols = 3;
 
   constexpr static unsigned int stride_rows = 2;
   constexpr static unsigned int stride_cols = 2;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 
-  constexpr static unsigned int input_rows = 5;
-  constexpr static unsigned int input_cols = 5;
+  sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 3, 3, 2, 2) {}
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_3x3_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_3x3_mla::get_packed_size;
+  Parent::KernelType kernel = sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
 
-  kern_type kernel = sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst_impl;
-
-  sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
index 400e62d..9adf100 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_3x3_s2_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -456,4 +456,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index f07ea13..06ca42e 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,47 +29,35 @@
 
 #pragma once
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
 
 void sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
 
-struct sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst
+class sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst : public DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>
 {
-  typedef int32_t bias_type;
-  typedef uint8_t input_type;
-  typedef int8_t weight_type;
-  typedef uint8_t return_type;
+  using Parent = DepthwiseDepthfirstStrategy<uint8_t, int8_t, uint8_t, int32_t>;
 
-  constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE;
-
-  typedef void (*kern_type)(unsigned int, const uint8_t *const *, const int8_t *, const int32_t *, const arm_gemm::Requantize32 &, const int32_t *, const int32_t *, uint8_t *const *);
-  typedef void (*parameter_packing_fn)(unsigned int, void *, const int8_t *, size_t, size_t);
-  typedef size_t (*parameter_sizing_fn)(const DepthwiseArgs &);
-
+  public:
   constexpr static unsigned int kernel_rows = 5;
   constexpr static unsigned int kernel_cols = 5;
 
   constexpr static unsigned int stride_rows = 1;
   constexpr static unsigned int stride_cols = 1;
 
-  constexpr static unsigned int output_rows = 2;
-  constexpr static unsigned int output_cols = 2;
+  arm_gemm::VLType get_vl_type(void) const override { return arm_gemm::VLType::SVE; }
+  unsigned int get_accumulator_depth_vl(void) const override { return 2; }
 
-  constexpr static unsigned int input_rows = 6;
-  constexpr static unsigned int input_cols = 6;
+  sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) : Parent(2, 2, 5, 5, 1, 1) {}
 
-  constexpr static parameter_packing_fn pack_parameters = interleave_sve_s8q_5x5_mla::pack_parameters;
-  constexpr static parameter_sizing_fn get_packed_size = interleave_sve_s8q_5x5_mla::get_packed_size;
+  Parent::KernelType kernel = sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
 
-  kern_type kernel = sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst_impl;
-
-  sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {}
+  Parent::KernelType get_kernel(void) const override { return kernel; }
 };
 
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
index 29582da..9cf95e9 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_u8s8u8q_nhwc_5x5_s1_output2x2_mla_depthfirst/generic.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,7 +27,7 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#if defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
 
 namespace arm_conv {
 namespace depthwise {
@@ -657,4 +657,4 @@
 }  // namespace depthwise
 }  // namespace arm_conv
 
-#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVE2)
+#endif  // defined(__aarch64__) && defined(ARM_COMPUTE_ENABLE_SVE)
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp b/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
new file mode 100644
index 0000000..e9b29ca
--- /dev/null
+++ b/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
@@ -0,0 +1,431 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/* Depthwise kernel drivers commonly require a per-thread blob of working space
+ * in which to store parameters required by the depthwise implementations. The
+ * composition of this working space varies with the driver, kernel, and data
+ * types -- but the tasks of requesting sufficient space, allocating buffer
+ * space, and performing initialisation of the working space are common.
+ *
+ * The classes in this file consist of a number of working space "Elements"
+ * (which are logical units of functionality) and a Workspace type which allows
+ * for compile time composition of elements into a single working space type.
+ *
+ * Creating a workspace
+ * ====================
+ *
+ * A new workspace type can be created by combining Elements as an argument to
+ * the Workspace class. For instance:
+ *
+ *   Workspace<
+ *     depthwise_depthfirst::InputArrayElement<float>,
+ *     InputBufferElement<float>,
+ *     OutputArrayElement<float>
+ *   >
+ *
+ * Creates a new Workspace consisting of the given elements. The workspace type
+ * contained within this class (`Workspace<...>::WorkspaceType`) is equivalent to:
+ *
+ *   struct WorkspaceType
+ *   {
+ *     const float **inptr_array;  // From InputArrayElement<float>
+ *     float *input_buffer;  // From InputBufferElement<float>
+ *     float **outptr_array;  // From OutputArrayElement<float>
+ *     float *output_buffer;  // From OutputArrayElement<float>
+ *   };
+ *
+ * Calling `Workspace<...>::get_sizeof_workspace(...)` will return the amount
+ * of space required to store the above struct and the elements contained
+ * within it. Once this space has been allocated, the workspace can be
+ * initialised by calling `Workspace<...>::initialise` with a pointer to the
+ * buffer and the same arguments. This will place a struct of type
+ * `Workspace<...>::WorkspaceType` at the start of the buffer, and share the
+ * remaining space between the specified elements. As this is all done at
+ * compile time, later code can access elements from the `WorkspaceType` by
+ * name.
+ *
+ * Writing a new element
+ * =====================
+ *
+ * Each Element must provide:
+ *  - A struct called "Workspace" containing the variables contained within
+ *    this portion of the workspace.
+ *  - A static method called `get_element_size` which returns the amount of
+ *    buffer space required by this element of the workspace (NOT including the
+ *    size of the Workspace struct). For example, an element which stores a
+ *    vector of pointers will return the amount of space required to store the
+ *    vector.
+ *  - A static method called `initialise` which accepts a pointer to a struct
+ *    which will be composed of the Element's `Workspace` struct (along with
+ *    other elements), a pointer to the start of the buffer allocated for this
+ *    portion of the workspace, and arguments to be used to initialise the
+ *    workspace. The Element should consume as much of the buffer as it
+ *    requires, initialise the Workspace, and then return the pointer to the
+ *    next free byte of the buffer.
+ *
+ * See the below elements for an example of how this should work.
+ */
+
+#pragma once
+
+#include "depthwise.hpp"
+#include "depthfirst_driver.hpp"
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+
+namespace arm_conv {
+namespace depthwise {
+namespace {  // anonymous because we expect this to appear in several compilation units
+
+/* Arguments used to size and initialise a workspace. The output stage is held
+ * by value so that a defaulted `os` temporary cannot leave a dangling reference. */
+template <class StratType, class OutputStage=Nothing>
+struct WorkspaceArgs
+{
+  const StratType *strategy;  // Non-owning pointer to the strategy
+  const DepthwiseArgs &depthwise_args;  // Borrowed; must outlive this object
+  const OutputStage output_stage;  // Copy of the caller's (or the default) output stage
+
+  WorkspaceArgs(const StratType *strat, const DepthwiseArgs &dwargs, const OutputStage &os = {})
+  : strategy(strat), depthwise_args(dwargs), output_stage(os)
+  {
+  }
+};
+
+
+/* Placeholder element: it adds no fields to the workspace struct and consumes
+ * no buffer space. Useful when a templated workspace needs a blank element for
+ * some sets of parameters.
+ */
+struct EmptyElement
+{
+  struct Workspace {};  // Contributes no members
+
+  template <class StratType, class OutputStage>
+  static size_t get_element_size(const WorkspaceArgs<StratType, OutputStage> &) { return 0; }
+
+  template <class WorkspaceType, class StratType, class OutputStage>
+  static void *initialise(WorkspaceType *, void *buffer, const WorkspaceArgs<StratType, OutputStage> &)
+  {
+    return buffer;  // Nothing consumed: the next element starts at the same address
+  }
+};
+
+
+/* Store fused activation clamps for a kernel.
+ *
+ * The clamps are derived from `DepthwiseArgs::activation`; no buffer space is used.
+ */
+template <typename T, class OutputStage=Nothing>
+class ActivationsElement
+{
+  public:
+  struct Workspace
+  {
+    T activation_min, activation_max;  // Lower and upper clamp values
+  };
+
+  template <typename StratType>
+  static size_t get_element_size(const WorkspaceArgs<StratType, OutputStage> &)
+  {
+    return 0;  // The clamps live in the Workspace struct itself
+  }
+
+  template <class WorkspaceType, class StratType>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<StratType, OutputStage> &args)
+  {
+    ws->activation_min = static_cast<T>(-std::numeric_limits<float>::infinity());  // Default: unbounded
+    ws->activation_max = static_cast<T>(std::numeric_limits<float>::infinity());
+
+    switch (args.depthwise_args.activation.type)
+    {
+      case arm_gemm::Activation::Type::BoundedReLU:
+        ws->activation_max = static_cast<T>(args.depthwise_args.activation.param1);
+        // Deliberate fall through: BoundedReLU also takes the ReLU lower bound of 0
+      case arm_gemm::Activation::Type::ReLU:
+        ws->activation_min = static_cast<T>(0);
+        break;
+      default:
+        break;  // Any other activation type leaves the clamps unbounded
+    }
+
+    return buffer;  // Buffer pointer is passed on unchanged
+  }
+};
+
+/* Specialisation for an `arm_gemm::Requantize32` output stage: the activation
+ * clamps are contained within that structure, so this element is left empty.
+ */
+template <typename T>
+class ActivationsElement<T, arm_gemm::Requantize32> : public EmptyElement
+{
+};
+
+
+/* Get the value with which to fill an input buffer. This generic version
+ * ignores the output stage and returns `0` (as a `char`, for use with `memset`).
+ */
+template <typename OutputStage>
+char get_input_buffer_fill_value(const OutputStage &)
+{
+  return 0;
+}
+
+/* Specialisation for quantized kernels: fill with `qp.a_offset` (per the original
+ * note, the zero offset of the input tensor), converted to the `char` return type.
+ */
+template <> char get_input_buffer_fill_value(const arm_gemm::Requantize32 &qp) __attribute__ ((unused));  // presumably silences unused-function warnings -- confirm
+template <> char get_input_buffer_fill_value(const arm_gemm::Requantize32 &qp)
+{
+  return qp.a_offset;
+}
+
+
+/* Container for a vector of padding values (one `T` per input channel) which
+ * the depthwise kernel can safely read. Every byte is set by `memset` to the
+ * value returned by `get_input_buffer_fill_value` for the output stage.
+ */
+template <typename T>
+class InputBufferElement
+{
+  public:
+  struct Workspace
+  {
+    T *input_buffer;  // Points at the padding vector inside the workspace buffer
+  };
+
+  template <typename StratType, typename OutputStage>
+  static size_t get_element_size(const WorkspaceArgs<StratType, OutputStage> &args)
+  {
+    return sizeof(T) * args.depthwise_args.input_channels;  // One T per input channel
+  }
+
+  template <class WorkspaceType, typename StratType, typename OutputStage>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<StratType, OutputStage> &args)
+  {
+    ws->input_buffer = reinterpret_cast<T*>(buffer);  // Claim the front of the buffer
+    memset(ws->input_buffer, get_input_buffer_fill_value(args.output_stage), get_element_size(args));  // Bytewise fill
+    return reinterpret_cast<char *>(buffer) + get_element_size(args);  // Next free byte
+  }
+};
+
+
+/* Container for an array of output pointers (output rows x output cols of the
+ * strategy) and a per-output-channel buffer to absorb unnecessary writes.
+ */
+template <typename T>
+class OutputArrayElement
+{
+  public:
+  struct Workspace
+  {
+    T **outptr_array;  // Array of `T *` output pointers
+    T *output_buffer;  // Scratch destination, one T per output channel
+  };
+
+  template <typename OutputStage>
+  static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    return sizeof_outptr_array(args) + sizeof_output_buffer(args);
+  }
+
+  template <class WorkspaceType, typename OutputStage>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    char *buffer_bytes = reinterpret_cast<char *>(buffer);  // Byte cursor into the buffer
+
+    ws->outptr_array = reinterpret_cast<T **>(buffer_bytes);  // Pointer array comes first
+    buffer_bytes += sizeof_outptr_array(args);
+
+    ws->output_buffer = reinterpret_cast<T *>(buffer_bytes);  // Scratch buffer follows it
+    buffer_bytes += sizeof_output_buffer(args);
+
+    return buffer_bytes;  // Next free byte
+  }
+
+  protected:
+  template <typename OutputStage>
+  static size_t sizeof_outptr_array(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    return sizeof(T *) * args.strategy->get_output_rows() * args.strategy->get_output_cols();  // Elements are `T *`
+  }
+
+  template <typename OutputStage>
+  static size_t sizeof_output_buffer(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+  {
+    return sizeof(T) * args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier;  // One T per output channel
+  }
+};
+
+
+/* Container for requantization parameters.
+ *
+ * Exposes bias, multiplier and right-shift vectors with one entry per output
+ * channel. Vectors supplied by the output stage are used in place; missing ones
+ * are created in the workspace (zero bias, replicated per-layer mul and shift).
+ */
+class RequantizationParametersElement
+{
+  public:
+  struct Workspace
+  {
+    const int32_t *bias, *requant_muls, *requant_shifts;  // Each points at caller data or into the workspace buffer
+  };
+
+  template <typename StratType>
+  static size_t get_element_size(const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
+  {
+    return sizeof_bias(args) + sizeof_requant_muls(args) + sizeof_requant_shifts(args);  // Only missing vectors take space
+  }
+
+  template <typename WorkspaceType, typename StratType>
+  static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
+  {
+    const auto n_output_channels = args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier;
+    char *buffer_bytes = reinterpret_cast<char *>(buffer);  // Byte cursor into the buffer
+
+    ws->bias = args.output_stage.bias;  // Start from whatever the output stage provides
+    ws->requant_muls = args.output_stage.per_channel_muls;
+    ws->requant_shifts = args.output_stage.per_channel_right_shifts;
+
+    if (ws->bias == nullptr)  // No bias supplied: substitute a zero-filled vector
+    {
+      ws->bias = reinterpret_cast<const int32_t *>(buffer_bytes);
+      memset(buffer_bytes, 0, sizeof_bias(args));
+      buffer_bytes += sizeof_bias(args);
+    }
+
+    if (ws->requant_muls == nullptr)  // No per-channel multipliers: replicate the per-layer one
+    {
+      ws->requant_muls = reinterpret_cast<const int32_t *>(buffer_bytes);
+      auto muls = reinterpret_cast<int32_t *>(buffer_bytes);  // Writable view of the same storage
+      buffer_bytes += sizeof_requant_muls(args);
+
+      for (auto n = 0u; n < n_output_channels; n++)
+      {
+        muls[n] = args.output_stage.per_layer_mul;
+      }
+    }
+
+    if (ws->requant_shifts == nullptr)  // No per-channel shifts: replicate the per-layer one
+    {
+      ws->requant_shifts = reinterpret_cast<int32_t *>(buffer_bytes);
+      auto shifts = reinterpret_cast<int32_t *>(buffer_bytes);  // Writable view of the same storage
+      buffer_bytes += sizeof_requant_shifts(args);
+
+      for (auto n = 0u; n < n_output_channels; n++)
+      {
+        shifts[n] = args.output_stage.per_layer_right_shift;
+      }
+    }
+
+    return buffer_bytes;  // Next free byte
+  }
+
+  protected:
+  template <typename StratType>
+  static size_t sizeof_bias(const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
+  {
+    return args.output_stage.bias != nullptr ?  // Zero bytes when the caller supplies a bias
+      0 : sizeof(int32_t) * args.depthwise_args.channel_multiplier * args.depthwise_args.input_channels;
+  }
+
+  template <typename StratType>
+  static size_t sizeof_requant_muls(const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
+  {
+    return args.output_stage.per_channel_muls != nullptr ?  // Zero bytes when per-channel muls exist
+      0 : sizeof(int32_t) * args.depthwise_args.channel_multiplier * args.depthwise_args.input_channels;
+  }
+
+  template <typename StratType>
+  static size_t sizeof_requant_shifts(const WorkspaceArgs<StratType, arm_gemm::Requantize32> &args)
+  {
+    return args.output_stage.per_channel_right_shifts != nullptr ?  // Zero bytes when per-channel shifts exist
+      0 : sizeof(int32_t) * args.depthwise_args.channel_multiplier * args.depthwise_args.input_channels;
+  }
+};
+
+
+template <typename ...Elements>  // Primary template: declared only, defined through the specialisations
+class Workspace;
+
+template <typename Element, typename ...Elements>  // Recursive case: peel off the first element
+class Workspace<Element, Elements...>
+{
+  public:
+  struct WorkspaceType : Element::Workspace, Workspace<Elements...>::WorkspaceType  // Fields of every element, via inheritance
+  {
+  };
+
+  template <class S, class T>
+  static void initialise(void *buffer, const WorkspaceArgs<S, T> &args)
+  {
+    // The struct sits at the start of the buffer; element storage begins
+    // directly after it (`ws + 1`) and is handed to each element in turn.
+    auto ws = reinterpret_cast<WorkspaceType *>(buffer);
+    initialise_elements(ws, ws + 1, args);
+  }
+
+  template <class S, class T=Nothing>
+  static size_t get_sizeof_workspace(const WorkspaceArgs<S, T> &args)
+  {
+    return sizeof(WorkspaceType) + get_element_sizes(args);  // Struct plus storage for every element
+  }
+
+  template <class S, class T>
+  static inline size_t get_element_sizes(const WorkspaceArgs<S, T> &args)
+  {
+    return Element::get_element_size(args) + Workspace<Elements...>::get_element_sizes(args);  // This element + the rest
+  }
+
+  template <class WorkspaceType, class S, class T>
+  static void initialise_elements(WorkspaceType *ws, void *buffer, const WorkspaceArgs<S, T> &args)
+  {
+    buffer = Element::initialise(ws, buffer, args);  // Next free byte; NOTE(review): no alignment padding is added between elements
+    Workspace<Elements...>::initialise_elements(ws, buffer, args);
+  }
+};
+
+template <>  // Base case of the recursion: no elements left
+class Workspace<>
+{
+  public:
+  struct WorkspaceType  // Empty base of the composed workspace struct
+  {
+  };
+
+  template <class S, class T>
+  static inline size_t get_element_sizes(const WorkspaceArgs<S, T> &)
+  {
+    return 0;  // No elements, so no storage
+  }
+
+  template <class WorkspaceType, class S, class T>
+  static void initialise_elements(WorkspaceType *, void *, const WorkspaceArgs<S, T> &)
+  {
+  }
+};
+
+}  // namespace {anonymous}
+}  // namespace depthwise
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/assembly/depthwise.hpp b/src/core/NEON/kernels/assembly/depthwise.hpp
index 9262ea0..3998dfb 100644
--- a/src/core/NEON/kernels/assembly/depthwise.hpp
+++ b/src/core/NEON/kernels/assembly/depthwise.hpp
@@ -59,6 +59,8 @@
 
     const DepthwiseConfig *config;
 
+    bool fast_mode = false;
+
     DepthwiseArgs(
         const CPUInfo *cpu_info,
         unsigned int kernel_rows, unsigned int kernel_cols,
@@ -83,15 +85,18 @@
 
 protected:
     const DepthwiseArgs m_args; // Copy of arguments
+
 public:
     std::string name() const
     {
         return _name;
     }
+
     void set_name(const std::string &n)
     {
         _name = n;
     }
+
     DepthwiseCommon(const DepthwiseArgs &args)
         : m_args(args) {};
     DepthwiseCommon(DepthwiseCommon &) = delete;
@@ -103,7 +108,7 @@
         void *const        output,
         void *const        working_space,
         const unsigned int thread_id,
-        const unsigned int n_threads) const override
+        const unsigned int n_threads) const override final
     {
         const size_t ld_input_col    = m_args.input_channels;
         const size_t ld_input_row    = ld_input_col * m_args.input_cols;
@@ -130,7 +135,7 @@
         size_t             ld_output_batch,
         void *const        working_space,
         const unsigned int thread_id,
-        const unsigned int n_threads) const override
+        const unsigned int n_threads) const override final
     {
         execute(
             m_args.n_batches, m_args.input_rows, m_args.input_cols,
@@ -142,7 +147,36 @@
             working_space, thread_id, n_threads);
     }
 
-    virtual void execute(
+    void execute(
+        unsigned int         batches,
+        unsigned int         input_height,
+        unsigned int         input_width,
+        unsigned int         channels,
+        const PaddingValues &padding,
+        const void          *input,
+        size_t               ld_input_col,
+        size_t               ld_input_row,
+        size_t               ld_input_batch,
+        const void          *parameters,
+        unsigned int         output_height,
+        unsigned int         output_width,
+        void                *output,
+        size_t               ld_output_col,
+        size_t               ld_output_row,
+        size_t               ld_output_batch,
+        void                *working_space,
+        unsigned int         thread_id,
+        unsigned int         n_threads) const override final
+    {
+        this->execute_internal(
+            batches, input_height, input_width, channels, padding, input,
+            ld_input_col, ld_input_row, ld_input_batch, parameters, output_height,
+            output_width, output, ld_output_col, ld_output_row, ld_output_batch,
+            working_space, thread_id, n_threads);
+    }
+
+protected:
+    virtual void execute_internal(
         unsigned int batches,
         unsigned int input_height,
         unsigned int input_width,
@@ -161,7 +195,7 @@
         size_t       ld_output_batch,
         void        *working_space,
         unsigned int thread_id,
-        unsigned int n_threads) const override = 0;
+        unsigned int n_threads) const = 0;
 };
 
 template <typename TInput, typename TWeight = TInput, typename TOutput = TInput>