COMPMID-1975: Update depthwise convolution.

Change-Id: Iad58672be35710a7ec2e918653d6d529709387e8
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/898
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp
index b33f276..674fc4d 100644
--- a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp
@@ -31,101 +31,73 @@
  */
 
 #include <algorithm>
+#include <cstdint>
 #include "arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/padding.hpp"
 #include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
 
 #pragma once
 
+#define MEMBERFN(TOUT) template < /* prefix for out-of-class DepthwiseConvolutionBase member definitions: template header, return type TOUT (empty for constructors), class qualifier */ \
+  unsigned int OutputTileRows, unsigned int OutputTileColumns,\
+  unsigned int KernelRows, unsigned int KernelColumns,\
+  unsigned int StrideRows, unsigned int StrideColumns,\
+  typename TIn, typename TBias, typename TOut,\
+  typename Derived\
+> TOUT DepthwiseConvolutionBase<\
+  OutputTileRows, OutputTileColumns,\
+  KernelRows, KernelColumns,\
+  StrideRows, StrideColumns,\
+  TIn, TBias, TOut, Derived\
+>
+
+using namespace neon_convolution_kernels;
+
 namespace depthwise
 {
 
+template <unsigned int KernelRows, unsigned int KernelColumns, size_t WeightSize, size_t BiasSize>
+struct PackParameters  // Parameter-packing routine keyed on kernel shape and weight/bias sizes; execute() is only declared here.
+{
+  static void execute(  // NOTE(review): WeightSize/BiasSize are presumably element sizes in bytes (the caller in this file passes sizeof(...)) — confirm against the definition
+    unsigned int n_channels,
+    void *buffer,  // presumably the destination of the packed weights and biases — definition not visible here
+    const void *weights,
+    unsigned int weight_row_stride,
+    unsigned int weight_col_stride,
+    const void *biases
+  );
+};
+
 const unsigned int CHANNEL_BLOCK = 16;
 
-namespace
-{
-  inline int pad_along_dim(
-    const bool padding_same,
-    const int kernel_dim,
-    const int stride_dim,
-    const int input_dim
-  )
-  {
-    if (!padding_same)
-      return 0;
-    if (input_dim % stride_dim)
-      return std::max(kernel_dim - (input_dim % stride_dim), 0);
-    else
-      return std::max(kernel_dim - stride_dim, 0);
-  }
-}  // namespace
-
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-int DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::get_output_size(
-  const int dim_size, const bool same_padding
-)
-{
-  return iceildiv(dim_size - (same_padding ? 0 : (KC - 1)), SR);
-}
-
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-int DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::get_output_size(
+MEMBERFN(int)::get_output_size(  // Output extent along one dimension, given the input extent and the padding on both sides.
   const int dim_size, const unsigned int padding_before, const unsigned int padding_after
 )
 {
-  return iceildiv(dim_size + padding_before + padding_after - KR + 1, SR);
+  return iceildiv(dim_size + padding_before + padding_after - KernelRows + 1, StrideRows);  // NOTE(review): always uses the row kernel size and row stride, whichever dimension dim_size describes — confirm only square kernels/strides are instantiated
 }
 
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::DepthwiseConvolution(
-  const int n_batches, const int n_input_rows, const int n_input_cols,
-  const int n_channels, const bool padding_same,
-  const TIn* const weights,
-  const TIn* const input,
-  TOut* const output,
-  const int weight_col_stride,
-  const int weight_row_stride,
-  const int input_col_stride,
-  const int input_row_stride,
-  const int input_batch_stride,
-  const int output_col_stride,
-  const int output_row_stride,
-  const int output_batch_stride
-) : DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>(
-  n_batches, n_input_rows, n_input_cols,
-  n_channels,
-  pad_along_dim(padding_same, KR, SR, n_input_rows) / 2,  /* top padding */
-  pad_along_dim(padding_same, KC, SC, n_input_cols) / 2,  /* left padding */
-  iceildiv(pad_along_dim(padding_same, KR, SR, n_input_rows), 2),  /* bottom padding */
-  iceildiv(pad_along_dim(padding_same, KC, SC, n_input_cols), 2),  /* right padding */
-  weights, input, output,
-  weight_col_stride, weight_row_stride,
-  input_col_stride, input_row_stride, input_batch_stride,
-  output_col_stride, output_row_stride, output_batch_stride
-)
+MEMBERFN(int)::output_size(  // Instance-level (const member) wrapper around get_output_size().
+  const int dim_size, const unsigned int padding_before, const unsigned int padding_after
+) const
 {
+  return get_output_size(dim_size, padding_before, padding_after);  // forwards the arguments unchanged
 }
 
-
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::DepthwiseConvolution(
-  const int n_batches, const int n_input_rows, const int n_input_cols,
+MEMBERFN()::DepthwiseConvolutionBase(
+  const int n_batches,
+  const int n_input_rows,
+  const int n_input_cols,
   const int n_channels,
+  ActivationFunction activation,
   const unsigned int padding_top,
   const unsigned int padding_left,
   const unsigned int padding_bottom,
-  const unsigned int padding_right,
-  const TIn* const weights,
-  const TIn* const input,
-  TOut* const output,
-  const int weight_col_stride,
-  const int weight_row_stride,
-  const int input_col_stride,
-  const int input_row_stride,
-  const int input_batch_stride,
-  const int output_col_stride,
-  const int output_row_stride,
-  const int output_batch_stride
-) : _weights(weights), _input(input), _output(output),
+  const unsigned int padding_right
+) : _input(nullptr), _output(nullptr),
+    _packed_parameters(nullptr),
+    _working_space(nullptr),
     _n_batches(n_batches),
     _n_input_rows(n_input_rows),
     _n_input_cols(n_input_cols),
@@ -138,37 +110,157 @@
     _padding_left(padding_left),
     _padding_bottom(padding_bottom),
     _padding_right(padding_right),
-    _weight_col_stride(weight_col_stride ? weight_col_stride : _n_channels),
-    _weight_row_stride(weight_row_stride ? weight_row_stride : KC * _weight_col_stride),
-    _input_col_stride(input_col_stride ? input_col_stride : _n_channels),
-    _input_row_stride(input_row_stride ? input_row_stride : _n_input_cols * _input_col_stride),
-    _input_batch_stride(input_batch_stride ? input_batch_stride : _n_input_rows * _input_row_stride),
-    _output_col_stride(output_col_stride ? output_col_stride : _n_channels),
-    _output_row_stride(output_row_stride ? output_row_stride : _n_output_cols * _output_col_stride),
-    _output_batch_stride(output_batch_stride ? output_batch_stride : _n_output_rows * _output_row_stride),
-    _input_offset(0), _weights_offset(0)
+    _activation(activation),
+    _input_col_stride(0), _input_row_stride(0), _input_batch_stride(0),
+    _input_ws_col_stride(_n_channels),
+    _input_ws_row_stride(_input_ws_col_stride * inner_tile_cols),
+    _output_col_stride(0), _output_row_stride(0), _output_batch_stride(0),
+    _output_ws_col_stride(_n_channels),
+    _output_ws_row_stride(_output_ws_col_stride * OutputTileColumns)
 {
 }
 
+MEMBERFN(void)::set_input(const void* const inptr)  // Set the input pointer, deriving every stride from the tensor shape.
+{
+  set_input(inptr, _n_channels);  // column stride defaults to the channel count
+}
 
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-unsigned int DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::get_window() const
+MEMBERFN(void)::set_input(const void* const inptr, const int ld_col)  // Set the input pointer with an explicit column stride.
+{
+  set_input(inptr, _n_input_cols * ld_col, ld_col);  // row stride defaults to one full row of columns
+}
+
+MEMBERFN(void)::set_input(const void* const inptr, const int ld_row, const int ld_col)  // Set the input pointer with explicit row and column strides.
+{
+  set_input(inptr, _n_input_rows * ld_row, ld_row, ld_col);  // batch stride defaults to one full image of rows
+}
+
+MEMBERFN(void)::set_input(const void* const inptr, const int ld_batch, const int ld_row, const int ld_col)  // Set the input pointer and all three strides explicitly.
+{
+  _input = static_cast<const TIn *>(inptr);  // only the pointer is stored; no data is copied
+  _input_batch_stride = ld_batch;  // strides are later applied to TIn pointers, i.e. counted in elements rather than bytes
+  _input_row_stride = ld_row;
+  _input_col_stride = ld_col;
+}
+
+MEMBERFN(void)::set_output(void* const outptr)  // Set the output pointer, deriving every stride from the tensor shape.
+{
+  set_output(outptr, _n_channels);  // column stride defaults to the channel count
+}
+
+MEMBERFN(void)::set_output(void* const outptr, const int ld_col)  // Set the output pointer with an explicit column stride.
+{
+  set_output(outptr, _n_output_cols * ld_col, ld_col);  // row stride defaults to one full row of output columns
+}
+
+MEMBERFN(void)::set_output(void* const outptr, const int ld_row, const int ld_col)  // Set the output pointer with explicit row and column strides.
+{
+  set_output(outptr, _n_output_rows * ld_row, ld_row, ld_col);  // batch stride defaults to one full image of output rows
+}
+
+MEMBERFN(void)::set_output(void* const outptr, const int ld_batch, const int ld_row, const int ld_col)  // Set the output pointer and all three strides explicitly.
+{
+  _output = static_cast<TOut *>(outptr);  // only the pointer is stored; nothing is written yet
+  _output_batch_stride = ld_batch;  // strides are later applied to TOut pointers, i.e. counted in elements rather than bytes
+  _output_row_stride = ld_row;
+  _output_col_stride = ld_col;
+}
+
+MEMBERFN(size_t)::get_packed_params_size(void) const  // Size in bytes of the packed-parameter buffer for all channels.
+{
+  return _n_channels * (sizeof(TIn)*KernelRows*KernelColumns + sizeof(TBias));  // per channel: one KernelRows x KernelColumns kernel of TIn plus one TBias
+}
+
+MEMBERFN(void)::set_packed_params_buffer(void *buffer)  // Register the buffer that holds the packed parameters.
+{
+  _packed_parameters = buffer;  // pointer only; pack_params(weights, biases) passes it on as the packing buffer and run() reads from it
+}
+
+MEMBERFN(void)::pack_params(const void *weights, const void *biases) const  // Pack into the buffer registered with set_packed_params_buffer().
+{
+  static_cast<const Derived *>(this)->pack_params(_packed_parameters, weights, biases);  // call goes through Derived so a derived-class overload takes precedence
+}
+
+MEMBERFN(void)::pack_params(void *buffer, const void *weights, const void *biases) const
+{
+  // Default strides describe a dense KernelRows x KernelColumns x n_channels weight array.
+  const unsigned int ld_weight_col = _n_channels;
+  const unsigned int ld_weight_row = KernelColumns * ld_weight_col;
+  const auto *const derived = static_cast<const Derived *>(this);
+  derived->pack_params(buffer, weights, ld_weight_row, ld_weight_col, biases);
+}
+
+MEMBERFN(void)::pack_params(
+  void * const buffer,
+  const void * const weights,
+  const unsigned int weight_row_stride,
+  const unsigned int weight_col_stride,
+  const void * const biases
+) const
+{
+  // Hand over to the packing hook; a _pack_params declared by Derived is found before the base default.
+  const auto *const derived = static_cast<const Derived *>(this);
+  derived->_pack_params(buffer, weights, weight_row_stride, weight_col_stride, biases);
+}
+
+MEMBERFN(void)::_pack_params(  // Default packing hook, used unless Derived declares its own _pack_params.
+  void * const buffer,
+  const void * const weights,
+  const unsigned int weight_row_stride,
+  const unsigned int weight_col_stride,
+  const void * const biases
+) const
+{
+  // BiasSize must be sizeof(TBias): get_packed_params_size() and run() both lay the buffer out with TBias-sized biases.
+  PackParameters<KernelRows, KernelColumns, sizeof(TIn), sizeof(TBias)>::execute(
+    _n_channels, buffer, weights, weight_row_stride, weight_col_stride, biases
+  );
+}
+
+MEMBERFN(size_t)::get_working_space_size(const unsigned int nthreads) const
+{
+  // Every thread gets one input scratch tile plus one output scratch tile.
+  const size_t per_thread = _get_input_working_space_size() + _get_output_working_space_size();
+  return nthreads * per_thread;
+}
+
+MEMBERFN(void)::set_working_space(void *buffer)  // Register the scratch buffer used for padded input/output tiles.
+{
+  _working_space = buffer;  // pointer only; per-thread slots are carved out of it by _get_input_working_space()/_get_output_working_space()
+}
+
+MEMBERFN(size_t)::_get_input_working_space_size(void) const  // Bytes of input scratch needed by one thread.
+{
+  return sizeof(TIn) * inner_tile_rows * inner_tile_cols * _n_channels;  // one inner_tile_rows x inner_tile_cols tile of TIn for every channel
+}
+
+MEMBERFN(size_t)::_get_output_working_space_size(void) const  // Bytes of output scratch needed by one thread.
+{
+  return sizeof(TOut) * OutputTileRows * OutputTileColumns * _n_channels;  // one OutputTileRows x OutputTileColumns tile of TOut for every channel
+}
+
+MEMBERFN(void *)::_get_input_working_space(const unsigned int threadid) const
+{
+  // A thread's slot holds its input tile followed by its output tile; slots are laid out back to back.
+  const size_t slot_bytes = _get_input_working_space_size() + _get_output_working_space_size();
+  return static_cast<uint8_t*>(_working_space) + threadid * slot_bytes;
+}
+
+MEMBERFN(void *)::_get_output_working_space(const unsigned int threadid) const  // Start of the given thread's output scratch tile.
+{
+  return static_cast<uint8_t*>(_get_input_working_space(threadid)) + _get_input_working_space_size();  // output tile sits directly after the thread's input tile
+}
+
+MEMBERFN(unsigned int)::get_window() const  // Size of the work window: the channel count divided into CHANNEL_BLOCK-sized blocks via iceildiv.
 {
   // Parallelise over blocks of channels.
   return iceildiv(_n_channels, CHANNEL_BLOCK);
 }
 
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::set_offsets(int input_offset, int weights_offset)
-{
-    _input_offset = input_offset;
-    _weights_offset = weights_offset;
-}
-
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::run(
+MEMBERFN(void)::run(
   const unsigned int start,
-  const unsigned int stop
+  const unsigned int stop,
+  const unsigned int threadid
 )
 {
   // Parallelise over blocks of channels
@@ -205,43 +297,38 @@
       const int output_row_bottom = (tile_i + 1)*output_tile_rows;
       const int output_row_pad_bottom = std::max(0, output_row_bottom - _n_output_rows);
 
+      // Get the offset into the packed parameters
+      const auto params_ptr = static_cast<const uint8_t*>(_packed_parameters) +
+        start_channel*(sizeof(TIn)*KernelRows*KernelColumns + sizeof(TBias));
+
       // Process the row
       process_tile_row(
+        threadid,
         stop_channel - start_channel,
-        _weights + start_channel, _weight_row_stride, _weight_col_stride,
-        inptr_row + start_channel, _input_row_stride, _input_col_stride,
-        outptr_row + start_channel, _output_row_stride, _output_col_stride,
+        params_ptr,
+        inptr_row + start_channel,
+        outptr_row + start_channel,
         input_row_pad_top, input_pad_left, input_row_pad_bottom,
         output_row_pad_bottom,
-        _n_tile_cols, _n_input_cols, _n_output_cols,
-        _input_offset, _weights_offset
+        _n_tile_cols, _n_input_cols, _n_output_cols
       );
     }
   }
 }
 
-
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::process_tile_row(
+MEMBERFN(void)::process_tile_row(
+  const unsigned int threadid,
   const int n_channels,
-  const TIn* const weights,
-  const int weight_row_stride,
-  const int weight_col_stride,
+  const void* const packed_params,
   const TIn* const inptr,
-  const int in_row_stride,
-  const int in_col_stride,
   TOut* const outptr,
-  const int out_row_stride,
-  const int out_col_stride,
   const int row_pad_in_top,
   const int row_pad_in_left,
   const int row_pad_in_bottom,
   const int row_pad_out_bottom,
   const int n_tiles,
   const int n_input_cols,
-  const int n_output_cols,
-  const int input_offset,
-  const int weights_offset
+  const int n_output_cols
 )
 {
   constexpr int tile_overlap = kernel_cols - stride_cols;
@@ -261,264 +348,97 @@
 
     // Get pointers into the inputs and outputs
     const int col_offset = (tile_j == 0) ? 0 : row_pad_in_left;
-    const TIn* const inptr_col = (inptr + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*in_col_stride);
-    TOut* const outptr_col = outptr + tile_j * output_tile_cols * out_col_stride;
+    const TIn* const inptr_col = (inptr + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*_input_col_stride);
+    TOut* const outptr_col = outptr + tile_j * output_tile_cols * _output_col_stride;
 
-    // Apply the specific tile processing function
-    const bool pad_top = row_pad_in_top > 0;
-    const bool pad_left = t_pad_in_left > 0;
-    const bool pad_bottom = row_pad_in_bottom || row_pad_out_bottom;
-    const bool pad_right = t_pad_in_right || t_pad_out_right;
-
-    const TileFn tilefn = [&] () {
-      if (!pad_top && !pad_left && !pad_bottom && !pad_right)
-      {
-        // No padding
-        return tilefn_unpadded;
-      }
-      else if (pad_top && !pad_left && !pad_bottom && !pad_right)
-      {
-        // Padding on the top only, subtract off the minimum expected padding in
-        // order to index into the array of specialised methods.
-        const int index = row_pad_in_top - min_in_pad_top;
-        return tilefn_top[index];
-      }
-      else if (!pad_top && pad_left && !pad_bottom && !pad_right)
-      {
-        // Padding on the left only, subtract off the minimum expected padding in
-        // order to index into the array of specialised methods.
-        const int index = t_pad_in_left - min_in_pad_left;
-        return tilefn_left[index];
-      }
-      else if (!pad_top && !pad_left && pad_bottom && !pad_right)
-      {
-        // Padding on the bottom only
-        return tilefn_bottom[row_pad_in_bottom][row_pad_out_bottom];
-      }
-      else if (!pad_top && !pad_left && !pad_bottom && pad_right)
-      {
-        // Padding on the right only
-        return tilefn_right[t_pad_in_right][t_pad_out_right];
-      }
-      else
-      {
-        // Otherwise use generic tile processing method.
-        return tilefn_generic;
-      }
-    }();
-
-    tilefn(
-      n_channels,
-      weights, weight_row_stride, weight_col_stride,
-      inptr_col, in_row_stride, in_col_stride,
-      outptr_col, out_row_stride, out_col_stride,
-      row_pad_in_top, t_pad_in_left, row_pad_in_bottom, t_pad_in_right,
-      row_pad_out_bottom, t_pad_out_right, input_offset, weights_offset
+    // Process just this tile
+    process_tile(
+      threadid, n_channels, packed_params, inptr_col, outptr_col,
+      row_pad_in_top, t_pad_in_left, row_pad_in_bottom, t_pad_in_right,  // Input paddings
+      row_pad_out_bottom, t_pad_out_right  // Output paddings
     );
   }
 }
 
-
-// New templated struct used solely as a way to provide tile processing
-// specialisations.
-template <int OutputTileRows, int OutputTileCols,
-          int KernelRows, int KernelCols,
-          int StrideRows, int StrideCols,
-          typename TIn, typename TOut>
-struct DepthwiseConvolutionImpl : public DepthwiseConvolution<
-    OutputTileRows, OutputTileCols,
-    KernelRows, KernelCols,
-    StrideRows, StrideCols, TIn, TOut
->
+MEMBERFN(TIn)::_input_padding_value(void) const  // Default value written into padded input cells; process_tile calls this through Derived, so a derived class can supply its own.
 {
-  typedef DepthwiseConvolution<
-    OutputTileRows, OutputTileCols,
-    KernelRows, KernelCols,
-    StrideRows, StrideCols,
-    TIn, TOut
-  > DWC;
+  return static_cast<TIn>(0);  // zero of the input type
+}
 
-  /** Perform the depthwise convolution of a tile.
-   *
-   * @param[in] n_channels Number of channels.
-   * @param[in] weights Pointer to Height x Width x Channels ordered weights.
-   * @param[in] inptr Pointer to the top-left unpadded value of the tile.
-   * @param[in] in_row_stride Stride between rows of the input tensor.
-   * @param[in] in_col_stride Stride between columns of the input tensor.
-   * @param[out] outptr Pointer to the top-left output value for the tile.
-   * @param[in] out_row_stride Stride between rows of the output tensor.
-   * @param[in] out_col_stride Stride between columns of the output tensor.
-   *
-   * The following parameters may be ignored if the function has been
-   * specialised for specific padding constraints.
-   *
-   * @param[in] _in_pad_top Padding to apply to top of input tile.
-   * @param[in] _in_pad_left Padding to apply to left of input tile.
-   * @param[in] _in_pad_bottom Padding to apply to bottom of input tile.
-   * @param[in] _in_pad_right Padding to apply to right of input tile.
-   * @param[in] _out_pad_bottom Null cells at bottom of output tile.
-   * @param[in] _out_pad_right Null cells at right of output tile.
-   */
-  template <
-    bool Specialize=false,  // Specialize (or not) the method
-    int InPadTop=0,         // If specialized, top padding
-    int InPadLeft=0,        // If specialized, left padding
-    int InPadBottom=0,      // If specialized, bottom padding
-    int InPadRight=0,       // If specialized, right padding
-    int OutPadBottom=0,     // If specialized, bottom output padding
-    int OutPadRight=0       // If specialized, bottom right padding
-  >
-  static void process_tile(
-    const int n_channels,
-    const TIn* const weights,
-    const int weight_row_stride,
-    const int weight_col_stride,
-    const TIn* const inptr,
-    const int in_row_stride,
-    const int in_col_stride,
-    TOut* const outptr,
-    const int out_row_stride,
-    const int out_col_stride,
-    const int in_pad_top=0,
-    const int in_pad_left=0,
-    const int in_pad_bottom=0,
-    const int in_pad_right=0,
-    const int out_pad_bottom=0,
-    const int out_pad_right=0,
-    const int input_offset=0,
-    const int weights_offset=0
-  );
-};
-
-
-template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
-template <
-  bool Specialize,
-  int InPadTop, int InPadLeft, int InPadBottom, int InPadRight,
-  int OutPadBottom, int OutPadRight
->
-void DepthwiseConvolutionImpl<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::process_tile(
+MEMBERFN(void)::process_tile(  // Run the kernel on one tile, staging through per-thread scratch buffers whenever padding is involved.
+  const unsigned int threadid,
   const int n_channels,
-  const TIn *__restrict__ const weights,
-  const int weight_row_stride,
-  const int weight_col_stride,
-  const TIn *__restrict__ const inptr,
-  const int in_row_stride,
-  const int in_col_stride,
-  TOut *__restrict__ const outptr,
-  const int out_row_stride,
-  const int out_col_stride,
-  const int _in_pad_top,
-  const int _in_pad_left,
-  const int _in_pad_bottom,
-  const int _in_pad_right,
-  const int _out_pad_bottom,
-  const int _out_pad_right,
-  const int _input_offset,
-  const int _weights_offset
+  const void* const packed_params,
+  const TIn* const inptr,
+  TOut* const outptr,
+  const int pad_in_top,
+  const int pad_in_left,
+  const int pad_in_bottom,
+  const int pad_in_right,
+  const int pad_out_bottom,
+  const int pad_out_right
 )
 {
-  constexpr auto inner_tile_rows = DWC::inner_tile_rows;
-  constexpr auto inner_tile_cols = DWC::inner_tile_cols;
-  constexpr auto kernel_rows = DWC::kernel_rows;
-  constexpr auto kernel_cols = DWC::kernel_cols;
-  constexpr auto output_tile_rows = DWC::output_tile_rows;
-  constexpr auto output_tile_cols = DWC::output_tile_cols;
-  constexpr auto stride_rows = DWC::stride_rows;
-  constexpr auto stride_cols = DWC::stride_cols;
+  const bool pad_input = pad_in_top || pad_in_left || pad_in_bottom || pad_in_right;  // any non-zero input padding forces a padded copy of the tile
+  const bool pad_output = pad_out_bottom || pad_out_right;  // any cropped output rows/columns make the kernel write to scratch first
 
-  // Extract parameters
-  const int in_pad_top = Specialize ? InPadTop : _in_pad_top;
-  const int in_pad_left = Specialize ? InPadLeft : _in_pad_left;
-  const int in_pad_bottom = Specialize ? InPadBottom : _in_pad_bottom;
-  const int in_pad_right = Specialize ? InPadRight : _in_pad_right;
-  const int out_pad_bottom = Specialize ? OutPadBottom : _out_pad_bottom;
-  const int out_pad_right = Specialize ? OutPadRight : _out_pad_right;
-
-  // Compute valid ranges of the tile
-  const int in_cells_i = inner_tile_rows - in_pad_bottom;
-  const int in_cells_j = inner_tile_cols - in_pad_right;
-  const int out_cells_i = output_tile_rows - out_pad_bottom;
-  const int out_cells_j = output_tile_cols - out_pad_right;
-
-  // Instantiate pointers
-  const TIn* __restrict__ inptr_base = inptr;
-  const TIn* __restrict__ wptr_base = weights;
-  TOut* __restrict__ outptr_base = outptr;
-
-  // Perform the depthwise convolution
-  int channels_remaining = n_channels;
-  for (; channels_remaining; channels_remaining--)
+  if (pad_input)
   {
-    // Load input tile
-    TIn u[inner_tile_rows][inner_tile_cols];
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      const TIn* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
-      for (int j = 0; j < inner_tile_cols; j++)
-      {
-        if (i < in_pad_top || in_cells_i <= i ||
-            j < in_pad_left || in_cells_j <= j)
-        {
-          u[i][j] = static_cast<TIn>(0);
-        }
-        else
-        {
-          u[i][j] = *(inptr_row + (j - in_pad_left)*in_col_stride);
-        }
-      }
-    }
-    inptr_base++;
-
-    // Load weights tile
-    TIn w[kernel_rows][kernel_cols];
-    for (int i = 0; i < kernel_rows; i++)
-    {
-      const TIn* const wptr_row = wptr_base + i*weight_row_stride;
-      for (int j = 0; j < kernel_cols; j++)
-      {
-        w[i][j] = *(wptr_row + j*weight_col_stride);
-      }
-    }
-    wptr_base++;
-
-    // Perform the convolution
-    TOut v[output_tile_rows][output_tile_cols];
-    for (int out_i = 0; out_i < out_cells_i; out_i++)
-    {
-      for (int out_j = 0; out_j < out_cells_j; out_j++)
-      {
-        // Clear the accumulator
-        v[out_i][out_j] = static_cast<TOut>(0);
-
-        // Base co-ordinate
-        const int base_i = out_i * stride_rows;
-        const int base_j = out_j * stride_cols;
-
-        // Fill the accumulator
-        for (int in_i = 0; in_i < kernel_rows; in_i++)
-        {
-          const int i = base_i + in_i;
-          for (int in_j = 0; in_j < kernel_cols; in_j++)
-          {
-            const int j = base_j + in_j;
-            v[out_i][out_j] += w[in_i][in_j] * u[i][j];
-          }
-        }
-      }
-    }
-
-    // Store the output tile
-    for (int i = 0; i < out_cells_i; i++)
-    {
-      TOut* __restrict__ const outptr_row = outptr_base + i*out_row_stride;
-      for (int j = 0; j < out_cells_j; j++)
-      {
-        *(outptr_row + j*out_col_stride) = v[i][j];
-      }
-    }
-    outptr_base++;
+    // Copy the input into the temporary buffer, applying padding
+    padding::copy_and_pad_tile<TIn>(
+      inner_tile_rows, inner_tile_cols, n_channels,
+      inptr, _input_row_stride, _input_col_stride,
+      static_cast<TIn *>(_get_input_working_space(threadid)), _input_ws_row_stride, _input_ws_col_stride,
+      pad_in_top, pad_in_left, pad_in_bottom, pad_in_right,
+      static_cast<Derived *>(this)->_input_padding_value()
+    );
   }
+
+  // Execute the kernel
+  const TIn * const tile_inptr = !pad_input ? inptr : static_cast<const TIn *>(_get_input_working_space(threadid));  // read the tensor directly unless a padded copy was made
+  const int in_row_stride = !pad_input ? _input_row_stride : _input_ws_row_stride;
+  const int in_col_stride = !pad_input ? _input_col_stride : _input_ws_col_stride;
+
+  TOut * const tile_outptr = !pad_output ? outptr : static_cast<TOut *>(_get_output_working_space(threadid));  // write the tensor directly unless the tile must be cropped afterwards
+  const int out_row_stride = !pad_output ? _output_row_stride : _output_ws_row_stride;
+  const int out_col_stride = !pad_output ? _output_col_stride : _output_ws_col_stride;
+
+  Derived * dthis = static_cast<Derived *>(this);  // execute_tile is looked up on the Derived type
+
+  switch(_activation)  // turn the runtime activation into a compile-time template argument
+  {
+    case ActivationFunction::ReLU:
+      dthis->template execute_tile<ActivationFunction::ReLU>(
+        n_channels, packed_params, tile_inptr, in_row_stride, in_col_stride, tile_outptr, out_row_stride, out_col_stride
+      );
+      break;
+    case ActivationFunction::ReLU6:
+      dthis->template execute_tile<ActivationFunction::ReLU6>(
+        n_channels, packed_params, tile_inptr, in_row_stride, in_col_stride, tile_outptr, out_row_stride, out_col_stride
+      );
+      break;
+    default:  // every other activation value runs the kernel without an activation
+      dthis->template execute_tile<ActivationFunction::None>(
+        n_channels, packed_params, tile_inptr, in_row_stride, in_col_stride, tile_outptr, out_row_stride, out_col_stride
+      );
+      break;
+  }
+
+  if (pad_output)
+  {
+    // Copy the output from the temporary buffer, removing unnecessary values
+    padding::CopyCropped<OutputTileRows, OutputTileColumns>::execute(
+      n_channels * sizeof(TOut),
+      _get_output_working_space(threadid), _output_ws_row_stride * sizeof(TOut), _output_ws_col_stride * sizeof(TOut),
+      outptr, _output_row_stride * sizeof(TOut), _output_col_stride * sizeof(TOut),
+      0, 0, pad_out_bottom, pad_out_right
+    );
+  }
+}
+
+MEMBERFN(int)::n_channels(void) const
+{
+  return _n_channels;
+
+MEMBERFN(int)::n_channels(void) const
+{
+  return _n_channels;
 }
 
 }  // namespace depthwise