COMPMID-873: Integrate RSH NEON Depthwise Convolution routine

Change-Id: Ida1e9a836bc518bfe5563e16bf7f92bde5fc13f7
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/118472
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
diff --git a/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h b/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h
index a441fb4..1367f37 100644
--- a/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h
+++ b/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h
@@ -25,13 +25,15 @@
 #define __ARM_COMPUTE_NEDEPTHWISECONVOLUTIONKERNEL3x3_H__
 
 #include "arm_compute/core/NEON/INEKernel.h"
+#include "arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp"
+
+#include <memory>
 
 namespace arm_compute
 {
 class ITensor;
 
-/** Interface for the kernel to run a 3x3 depthwise convolution on a tensor.
- */
+/** Interface for the kernel to run a 3x3 depthwise convolution on a tensor. */
 class NEDepthwiseConvolutionLayer3x3Kernel : public INEKernel
 {
 public:
@@ -51,24 +53,47 @@
     NEDepthwiseConvolutionLayer3x3Kernel &operator=(NEDepthwiseConvolutionLayer3x3Kernel &&) = default;
     /** Initialize the function's source, destination, conv and border_size.
      *
-     * @param[in]  input     Source tensor. DataType supported: QASYMM8, F32.
-     * @param[in]  weights   Weights tensor. This is a 3D tensor with dimensions [3, 3, IFM]. Data type supported: Same as @p input.
-     * @param[out] output    Destination tensor. Data type supported: Same as @p input.
-     * @param[in]  conv_info Padding and stride information to use for the convolution.
+     * @param[in]  input       Source tensor. DataType supported: QASYMM8, F32.
+     * @param[in]  weights     Weights tensor. This is a 3D tensor with dimensions [3, 3, IFM]. Data type supported: Same as @p input.
+     * @param[out] output      Destination tensor. Data type supported: Same as @p input.
+     * @param[in]  conv_info   Padding and stride information to use for the convolution.
+     * @param[in]  data_layout (Optional) Data layout of the input and weights tensors. Defaults to NCHW.
      */
-    void configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info);
+    void configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info, DataLayout data_layout = DataLayout::NCHW);
+    /** Static method that checks if optimized execution is supported for the given parameters
+     *
+     * @param[in] input_shape Input shape
+     * @param[in] conv_info   Padding and stride information to use for the convolution.
+     * @param[in] dt          Data type of the input and weights
+     * @param[in] data_layout (Optional) Data layout of the input and weights tensors. Defaults to NCHW.
+     *
+     * @return True if the optimized kernels can be executed else false
+     */
+    static bool is_optimized_execution_possible(TensorShape input_shape, PadStrideInfo conv_info, DataType dt, DataLayout data_layout = DataLayout::NCHW);
+    /** Generates the convolver object */
+    void generate_convolver();
 
     // Inherited methods overridden:
     void run(const Window &window, const ThreadInfo &info) override;
     BorderSize border_size() const override;
 
 private:
-    BorderSize     _border_size;
-    const ITensor *_input;
-    ITensor       *_output;
-    const ITensor *_weights;
-    PadStrideInfo  _conv_info;
-    unsigned int   _num_elems_written_per_iteration;
+    void configure_generic();
+    void configure_optimized();
+    void run_generic(const Window &window, const ThreadInfo &info);
+    void run_optimized(const Window &window, const ThreadInfo &info);
+    std::unique_ptr<depthwise::IDepthwiseConvolution> create_convolver_object(TensorShape shape, PadStrideInfo conv_info,
+                                                                              const uint8_t *w_ptr, uint8_t *in_ptr, uint8_t *out_ptr);
+
+private:
+    BorderSize                                        _border_size;
+    const ITensor                                    *_input;
+    ITensor                                          *_output;
+    const ITensor                                    *_weights;
+    PadStrideInfo                                     _conv_info;
+    std::unique_ptr<depthwise::IDepthwiseConvolution> _convolver;
+    unsigned int                                      _num_elems_written_per_iteration;
+    bool                                              _run_optimized;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_NEDEPTHWISECONVOLUTIONKERNEL3x3_H__ */
diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
index 97532f3..a8645dc 100644
--- a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
@@ -25,10 +25,10 @@
 #define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__
 
 #include "arm_compute/core/NEON/INEKernel.h"
-#include "arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp"
-#include "arm_compute/core/NEON/kernels/winograd/convolution.hpp"
-#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp"
-#include "arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/convolution.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
 
 namespace arm_compute
 {
diff --git a/arm_compute/core/NEON/kernels/winograd/alloc.hpp b/arm_compute/core/NEON/kernels/convolution/common/alloc.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/alloc.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/alloc.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/arm.hpp b/arm_compute/core/NEON/kernels/convolution/common/arm.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/arm.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/arm.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/convolution.hpp b/arm_compute/core/NEON/kernels/convolution/common/convolution.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/convolution.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/convolution.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/convolution/common/perf.h
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/perf.h
rename to arm_compute/core/NEON/kernels/convolution/common/perf.h
diff --git a/arm_compute/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/convolution/common/profiler.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/profiler.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/profiler.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/convolution/common/shims.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/shims.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/shims.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/tensor.hpp b/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/tensor.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp b/arm_compute/core/NEON/kernels/convolution/common/tensor_utils.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/tensor_utils.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/convolution/common/utils.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/utils.hpp
rename to arm_compute/core/NEON/kernels/convolution/common/utils.hpp
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp
new file mode 100644
index 0000000..80b0614
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp
@@ -0,0 +1,209 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+namespace depthwise
+{
+/** Abstract interface implemented by the templated DepthwiseConvolution class declared below. */
+class IDepthwiseConvolution
+{
+public:
+    virtual ~IDepthwiseConvolution() = default;
+    virtual int output_size(const int dim_size, const bool padding_same) const = 0;  // Number of output rows/columns for a dimension of `dim_size` elements.
+    virtual unsigned int get_window(void) const = 0;  // Size of the window of work exposed by the operator.
+    virtual void run(const unsigned int start, const unsigned int stop) = 0;  // Perform the window of work [start, stop).
+};
+/** Tiled depthwise convolution engine, parameterised on output tile size, kernel size, stride and the input/output element types. */
+template <
+  int OutputTileRows,
+  int OutputTileCols,
+  int KernelRows,
+  int KernelCols,
+  int StrideRows,
+  int StrideCols,
+  typename TIn,
+  typename TOut
+>
+class DepthwiseConvolution : public IDepthwiseConvolution
+{
+  public:
+    typedef TIn InputType;
+    typedef TOut OutputType;
+
+    // Information about the specific convolution instance
+    static constexpr int output_tile_rows = OutputTileRows;
+    static constexpr int output_tile_cols = OutputTileCols;
+    static constexpr int kernel_rows = KernelRows;
+    static constexpr int kernel_cols = KernelCols;
+    static constexpr int stride_rows = StrideRows;
+    static constexpr int stride_cols = StrideCols;
+    static constexpr int inner_tile_rows = stride_rows * output_tile_rows + kernel_rows - 1;
+    static constexpr int inner_tile_cols = stride_cols * output_tile_cols + kernel_cols - 1;
+
+    /** Create a new depthwise convolution engine.
+     *
+     * @param[in] n_batches Number of batches in the tensors.
+     * @param[in] n_input_rows Number of rows in input tensor.
+     * @param[in] n_input_cols Number of columns in input tensor.
+     * @param[in] n_channels Number of channels in input and output tensors.
+     * @param[in] padding_same True if padding is SAME, else VALID.
+     * @param[in] weights Pointer to Height x Width x Channel ordered weights.
+     * @param[in] input Pointer to NHWC ordered input tensor.
+     * @param[out] output Pointer to NHWC ordered output tensor.
+     */
+    DepthwiseConvolution(
+      const int n_batches, const int n_input_rows, const int n_input_cols,
+      const int n_channels, const bool padding_same,
+      const TIn* const weights,
+      const TIn* const input,
+      TOut* const output
+    );
+
+    // Cannot copy or move a DepthwiseConvolution.
+    DepthwiseConvolution(const DepthwiseConvolution&) = delete;
+    DepthwiseConvolution& operator=(const DepthwiseConvolution&) = delete;
+
+    /** Get the number of output rows/columns.
+     *
+     * @param[in] dim_size Number of elements in the dimension (rows/columns)
+     * @param[in] padding_same True if the padding is SAME, otherwise false.
+     */
+    static int get_output_size(const int dim_size, const bool padding_same);
+
+    /** Get the number of output rows/columns; forwards to the static get_output_size().
+     *
+     * @param[in] dim_size Number of elements in the dimension (rows/columns)
+     * @param[in] padding_same True if the padding is SAME, otherwise false.
+     */
+    int output_size(const int dim_size, const bool padding_same) const override
+    {
+        return DepthwiseConvolution<OutputTileRows,
+                                    OutputTileCols,
+                                    KernelRows,
+                                    KernelCols,
+                                    StrideRows,
+                                    StrideCols,
+                                    TIn,
+                                    TOut>::get_output_size(dim_size, padding_same);
+    }
+
+    /** Get the window of work to be performed by an instance of the operator.
+     */
+    unsigned int get_window(void) const override;
+
+    /** Perform a portion of the work associated with the operator.
+     *
+     * Will perform the window of work described by $[start, stop)$.
+     *
+     * @param[in] start Start of the window of work to perform.
+     * @param[in] stop End of the work to perform.
+     */
+    void run(const unsigned int start, const unsigned int stop) override;
+
+  protected:
+    /** Process one row of tiles, handing each tile to the `tile_fns` entry matching its padding.
+     */
+    static void process_tile_row(
+      const int n_channels,
+      const TIn* const weights,
+      const TIn* const inptr,
+      const int in_row_stride,
+      const int in_col_stride,
+      TOut* const outptr,
+      const int out_row_stride,
+      const int out_col_stride,
+      const int row_pad_in_top,
+      const int row_pad_in_left,
+      const int row_pad_in_bottom,
+      const int row_pad_out_bottom,
+      const int n_tiles,
+      const int n_input_cols,
+      const int n_output_cols
+    );
+
+    /** Process a single tile of the tensors.
+     *
+     * @param[in] n_channels Number of channels.
+     * @param[in] weights Pointer to Height x Width x Channels ordered weights.
+     * @param[in] inptr Pointer to the top-left unpadded value of the tile.
+     * @param[in] in_row_stride Stride between rows of the input tensor.
+     * @param[in] in_col_stride Stride between columns of the input tensor.
+     * @param[out] outptr Pointer to the top-left output value for the tile.
+     * @param[in] out_row_stride Stride between rows of the output tensor.
+     * @param[in] out_col_stride Stride between columns of the output tensor.
+     */
+    template <
+      int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+      int out_pad_bottom, int out_pad_right
+    >
+    static void process_tile(
+      const int n_channels,
+      const TIn* const weights,
+      const TIn* const inptr,
+      const int in_row_stride,
+      const int in_col_stride,
+      TOut* const outptr,
+      const int out_row_stride,
+      const int out_col_stride
+    );
+
+    // Type of a pointer to a `process_tile` instance
+    typedef void (*TileFn)(
+      const int,
+      const TIn* const,
+      const TIn* const, const int, const int,
+      TOut* const, const int, const int
+    );
+
+    // Extents of the `tile_fns` table below, i.e. exclusive upper bounds on the
+    // padding values which can be applied to tiles in this class of convolution.
+    static constexpr int max_in_pad_top = 2;
+    static constexpr int max_in_pad_left = 2;
+    static constexpr int max_in_pad_bottom = inner_tile_rows - 1;
+    static constexpr int max_in_pad_right = inner_tile_cols - 1;
+    static constexpr int max_out_pad_bottom = output_tile_rows;
+    static constexpr int max_out_pad_right = output_tile_cols;
+
+    /** Array of methods to process tensor tiles.
+     *
+     * Allows dynamic dispatch to specialized implementations based on
+     * different padding configurations.
+     */
+    static const TileFn tile_fns[
+      max_in_pad_top][max_in_pad_left][max_in_pad_bottom][max_in_pad_right][
+      max_out_pad_bottom][max_out_pad_right
+    ];
+
+  private:
+    // Member variables of instances of a convolution engine.
+    const TIn* const _weights;
+    const TIn* const _input;
+    TOut* const _output;
+    const int _n_batches, _n_input_rows, _n_input_cols, _n_channels,
+              _n_output_rows, _n_output_cols, _n_tile_rows, _n_tile_cols;
+    const bool _padding_same;
+};
+
+}  // namespace depthwise
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp
new file mode 100644
index 0000000..f9671fc
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp
@@ -0,0 +1,348 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/*
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ *
+ *          NOTE: Header to be included by implementation files only.
+ *
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ */
+
+#include <algorithm>
+#include "arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
+#pragma once
+
+namespace depthwise
+{
+// Output extent of one dimension. NOTE(review): KC and SR are used for rows and columns alike, so this presumably assumes KR == KC and SR == SC - confirm.
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+int DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::get_output_size(
+  const int dim_size, const bool same_padding
+)
+{
+  return iceildiv(dim_size - (same_padding ? 0 : (KC - 1)), SR);
+}
+
+// Constructor: stores the tensor pointers and sizes, then derives the output size and the number of tile rows/columns.
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::DepthwiseConvolution(
+  const int n_batches, const int n_input_rows, const int n_input_cols,
+  const int n_channels, const bool padding_same,
+  const TIn* const weights,
+  const TIn* const input,
+  TOut* const output
+) : _weights(weights), _input(input), _output(output),
+    _n_batches(n_batches),
+    _n_input_rows(n_input_rows),
+    _n_input_cols(n_input_cols),
+    _n_channels(n_channels),
+    _n_output_rows(get_output_size(n_input_rows, padding_same)),
+    _n_output_cols(get_output_size(n_input_cols, padding_same)),
+    _n_tile_rows(iceildiv(_n_output_rows, output_tile_rows)),
+    _n_tile_cols(iceildiv(_n_output_cols, output_tile_cols)),
+    _padding_same(padding_same)
+{
+}
+
+// Size of the window of work; currently always 1, so the whole convolution is a single unit of work.
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+unsigned int DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::get_window() const
+{
+  // TODO Later support parallelisation over tile rows.
+  return 1;  // _n_tile_rows;
+}
+
+// Run the convolution: every tile row of every batch is processed, whatever window is requested.
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::run(
+  const unsigned int start,
+  const unsigned int stop
+)
+{
+  // TODO Later support parallelisation over tile rows; until then [start, stop) is ignored.
+  (void) start;
+  (void) stop;
+
+  // Compute input striding
+  const int input_col_stride = _n_channels;
+  const int input_row_stride = _n_input_cols * input_col_stride;
+  const int input_batch_stride = _n_input_rows * input_row_stride;
+
+  // Compute output striding
+  const int output_col_stride = _n_channels;
+  const int output_row_stride = _n_output_cols * output_col_stride;
+  const int output_batch_stride = _n_output_rows * output_row_stride;
+
+  // Compute top and bottom padding for input and output
+  const int input_pad_top = _padding_same ?
+                            ((_n_output_rows - 1)*stride_rows + kernel_rows - _n_input_rows) / 2 : 0;
+  const int input_pad_left = _padding_same ?
+                             ((_n_output_cols - 1)*stride_cols + kernel_cols - _n_input_cols) / 2 : 0;
+  constexpr int tile_overlap = kernel_rows - 1;
+
+  // Perform the convolution by calling `process_tile_row` for each tile row in
+  // each batch.
+  for (int batch = 0; batch < _n_batches; batch++)
+  {
+    const TIn* const inptr_batch = _input + batch*input_batch_stride;
+    TOut* const outptr_batch = _output + batch*output_batch_stride;
+
+    // Loop over rows of tiles
+    for (int tile_i = 0; tile_i < _n_tile_rows; tile_i++)
+    {
+      // Pointers to the first input row (top padding excluded) and first output row of this tile row
+      const int input_row_offset = (tile_i == 0) ? 0 : input_pad_top;
+      const TIn* const inptr_row = (inptr_batch + ((inner_tile_rows - tile_overlap)*tile_i - input_row_offset)*input_row_stride);
+      TOut* const outptr_row = outptr_batch + output_tile_rows * tile_i * output_row_stride;
+
+      // Input padding (top + bottom) for the row
+      const int input_row_top = tile_i*(inner_tile_rows - tile_overlap) - input_pad_top;
+      const int input_row_bottom = input_row_top + inner_tile_rows;
+      const int input_row_pad_top = (tile_i == 0) ? input_pad_top : 0;
+      const int input_row_pad_bottom = std::max(0, input_row_bottom - _n_input_rows);
+
+      // Output padding (bottom) for the row
+      const int output_row_bottom = (tile_i + 1)*output_tile_rows;
+      const int output_row_pad_bottom = std::max(0, output_row_bottom - _n_output_rows);
+
+      // Process the row
+      process_tile_row(
+        _n_channels, _weights,
+        inptr_row, input_row_stride, input_col_stride,
+        outptr_row, output_row_stride, output_col_stride,
+        input_row_pad_top, input_pad_left, input_row_pad_bottom,
+        output_row_pad_bottom,
+        _n_tile_cols, _n_input_cols, _n_output_cols
+      );
+    }
+  }
+}
+
+// Walk the tiles of one tile row, work out each tile's left/right padding and call the matching `tile_fns` entry.
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::process_tile_row(
+  const int n_channels,
+  const TIn* const weights,
+  const TIn* const inptr,
+  const int in_row_stride,
+  const int in_col_stride,
+  TOut* const outptr,
+  const int out_row_stride,
+  const int out_col_stride,
+  const int row_pad_in_top,
+  const int row_pad_in_left,
+  const int row_pad_in_bottom,
+  const int row_pad_out_bottom,
+  const int n_tiles,
+  const int n_input_cols,
+  const int n_output_cols
+)
+{
+  constexpr int tile_overlap = kernel_cols - 1;
+
+  // Loop over columns of tiles
+  for (int tile_j = 0; tile_j < n_tiles; tile_j++)
+  {
+    // Input padding (left + right) for the tile; only the first tile carries left padding
+    const int t_pad_in_left = (tile_j == 0) ? row_pad_in_left : 0;
+    const int t_in_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_in_left;
+    const int t_in_end = t_in_start + inner_tile_cols;
+    const int t_pad_in_right = std::max(0, t_in_end - n_input_cols);
+
+    // Output padding (right) for the tile
+    const int t_out_end = (tile_j + 1) * output_tile_cols;
+    const int t_pad_out_right = std::max(0, t_out_end - n_output_cols);
+
+    // Get pointers into the inputs and outputs
+    const int col_offset = (tile_j == 0) ? 0 : row_pad_in_left;
+    const TIn* const inptr_col = (inptr + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*in_col_stride);
+    TOut* const outptr_col = outptr + tile_j * output_tile_cols * out_col_stride;
+
+    // Apply the specific tile processing function, selected by the six padding values
+    tile_fns[row_pad_in_top][t_pad_in_left][row_pad_in_bottom][t_pad_in_right][row_pad_out_bottom][t_pad_out_right](
+      n_channels, weights,
+      inptr_col, in_row_stride, in_col_stride,
+      outptr_col, out_row_stride, out_col_stride
+    );
+  }
+}
+
+// Generic tile routine: for each channel, loads a zero-padded input tile and the weights, convolves them and stores the valid outputs.
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+template <
+  int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+  int out_pad_bottom, int out_pad_right
+>
+void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::process_tile(
+  const int n_channels,
+  const TIn* const weights,
+  const TIn* const inptr,
+  const int in_row_stride,
+  const int in_col_stride,
+  TOut* const outptr,
+  const int out_row_stride,
+  const int out_col_stride
+)
+{
+  // Compute valid ranges of the tile
+  constexpr int in_cells_i = inner_tile_rows - in_pad_bottom;
+  constexpr int in_cells_j = inner_tile_cols - in_pad_right;
+  constexpr int out_cells_i = output_tile_rows - out_pad_bottom;
+  constexpr int out_cells_j = output_tile_cols - out_pad_right;
+
+  // Instantiate pointers
+  const TIn* inptr_base = inptr;
+  const TIn* wptr_base = weights;
+  TOut* outptr_base = outptr;
+
+  const int weight_col_stride = n_channels;
+  const int weight_row_stride = kernel_cols * n_channels;
+
+  // Perform the depthwise convolution, one channel per iteration
+  int channels_remaining = n_channels;
+  for (; channels_remaining > 0; channels_remaining--)
+  {
+    // Load input tile
+    TIn u[inner_tile_rows][inner_tile_cols];
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      const int row_offset = (i - in_pad_top)*in_row_stride;  // Integer, not pointer: padded rows lie outside the tensor
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        if (i < in_pad_top || in_cells_i <= i ||
+            j < in_pad_left || in_cells_j <= j)
+        {
+          u[i][j] = static_cast<TIn>(0);
+        }
+        else
+        {
+          u[i][j] = *(inptr_base + (row_offset + (j - in_pad_left)*in_col_stride));
+        }
+      }
+    }
+    inptr_base++;
+
+    // Load weights tile
+    TIn w[kernel_rows][kernel_cols];
+    for (int i = 0; i < kernel_rows; i++)
+    {
+      const TIn* const wptr_row = wptr_base + i*weight_row_stride;
+      for (int j = 0; j < kernel_cols; j++)
+      {
+        w[i][j] = *(wptr_row + j*weight_col_stride);
+      }
+    }
+    wptr_base++;
+
+    // Perform the convolution
+    TOut v[out_cells_i][out_cells_j];
+    for (int out_i = 0; out_i < out_cells_i; out_i++)
+    {
+      for (int out_j = 0; out_j < out_cells_j; out_j++)
+      {
+        // Clear the accumulator
+        v[out_i][out_j] = static_cast<TOut>(0);
+
+        // Base co-ordinate
+        const int base_i = out_i * stride_rows;
+        const int base_j = out_j * stride_cols;
+
+        // Fill the accumulator
+        for (int in_i = 0; in_i < kernel_rows; in_i++)
+        {
+          const int i = base_i + in_i;
+          for (int in_j = 0; in_j < kernel_cols; in_j++)
+          {
+            const int j = base_j + in_j;
+            v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+          }
+        }
+      }
+    }
+
+    // Store the output tile
+    for (int i = 0; i < out_cells_i; i++)
+    {
+      TOut* const outptr_row = outptr_base + i*out_row_stride;
+      for (int j = 0; j < out_cells_j; j++)
+      {
+        *(outptr_row + j*out_col_stride) = v[i][j];
+      }
+    }
+    outptr_base++;
+  }
+}
+
+
+// Thin layer over DepthwiseConvolution which exists solely so that tile processing
+// can be specialised; this primary template only forwards to the parent class.
+template <int OutputTileRows, int OutputTileCols,
+          int KernelRows, int KernelCols,
+          int StrideRows, int StrideCols,
+          typename TIn, typename TOut>
+struct DepthwiseConvolutionImpl : public DepthwiseConvolution<
+    OutputTileRows, OutputTileCols,
+    KernelRows, KernelCols,
+    StrideRows, StrideCols, TIn, TOut
+>
+{
+  template <
+    int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+    int out_pad_bottom, int out_pad_right
+  >
+  static void process_tile(
+    const int n_channels,
+    const TIn* const weights,
+    const TIn* const inptr,
+    const int in_row_stride,
+    const int in_col_stride,
+    TOut* const outptr,
+    const int out_row_stride,
+    const int out_col_stride
+  )
+  {
+    // By default, redirect to the parent's generic process_tile with the same padding
+    // parameters. Specialised implementations can be added by overriding this method.
+    DepthwiseConvolution<OutputTileRows, OutputTileCols,
+                         KernelRows, KernelCols,
+                         StrideRows, StrideCols,
+                         TIn, TOut>::
+      template process_tile<in_pad_top, in_pad_left, in_pad_bottom, in_pad_right,
+                            out_pad_bottom, out_pad_right>(
+        n_channels,
+        weights,
+        inptr,
+        in_row_stride,
+        in_col_stride,
+        outptr,
+        out_row_stride,
+        out_col_stride
+    );
+  }
+};
+
+}  // namespace depthwise
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp
new file mode 100644
index 0000000..e7f0609
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp
@@ -0,0 +1,263 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/*
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ *
+ *          NOTE: Header to be included by implementation files only.
+ *
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ */
+
+#include "arm_compute/core/NEON/kernels/convolution/common/arm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp"
+
+#pragma once
+
+namespace depthwise
+{
+// Partial specialisation for float input and float output (FP32 to FP32)
+template <int OutputTileRows, int OutputTileCols,
+          int KernelRows, int KernelCols,
+          int StrideRows, int StrideCols>
+struct DepthwiseConvolutionImpl<OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols, float, float>
+{
+  typedef DepthwiseConvolution<
+    OutputTileRows, OutputTileCols,
+    KernelRows, KernelCols,
+    StrideRows, StrideCols,
+    float, float
+  > DWC;  // Matching convolution type; process_tile reads its tile/kernel/stride constants
+
+  template <
+    int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+    int out_pad_bottom, int out_pad_right
+  >  // Padding of the tile, fixed at compile time
+  static void process_tile(
+    const int n_channels,         // Number of channels to process
+    const float* const weights,   // Kernel values; channel index varies fastest
+    const float* const inptr,     // First non-padded input cell of the tile
+    const int in_row_stride,      // Distance between input rows, in floats
+    const int in_col_stride,      // Distance between input columns, in floats
+    float* const outptr,          // First output cell of the tile
+    const int out_row_stride,     // Distance between output rows, in floats
+    const int out_col_stride      // Distance between output columns, in floats
+  );
+};
+
+// FP32 process_tile: computes one output tile for every channel; padding is fixed at compile time.
+template <int OTR, int OTC, int KR, int KC, int SR, int SC>
+template <
+  int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+  int out_pad_bottom, int out_pad_right
+>
+void DepthwiseConvolutionImpl<OTR, OTC, KR, KC, SR, SC, float, float>::process_tile(
+  const int n_channels,         // Number of channels to process
+  const float* const weights,   // Kernel values; channel index varies fastest
+  const float* const inptr,     // First non-padded input cell of the tile
+  const int in_row_stride,      // Distance between input rows, in floats
+  const int in_col_stride,      // Distance between input columns, in floats
+  float* const outptr,          // First output cell of the tile
+  const int out_row_stride,     // Distance between output rows, in floats
+  const int out_col_stride      // Distance between output columns, in floats
+)
+{
+  constexpr auto inner_tile_rows = DWC::inner_tile_rows;
+  constexpr auto inner_tile_cols = DWC::inner_tile_cols;
+  constexpr auto kernel_rows = DWC::kernel_rows;
+  constexpr auto kernel_cols = DWC::kernel_cols;
+  constexpr auto output_tile_rows = DWC::output_tile_rows;
+  constexpr auto output_tile_cols = DWC::output_tile_cols;
+  constexpr auto stride_rows = DWC::stride_rows;
+  constexpr auto stride_cols = DWC::stride_cols;
+
+  // Exclusive upper bounds of the non-padded input cells, and the output cells to produce
+  constexpr int in_cells_i = inner_tile_rows - in_pad_bottom;
+  constexpr int in_cells_j = inner_tile_cols - in_pad_right;
+  constexpr int out_cells_i = output_tile_rows - out_pad_bottom;
+  constexpr int out_cells_j = output_tile_cols - out_pad_right;
+
+  // Moving cursors; each is advanced after a batch of channels has been handled
+  const float* inptr_base = inptr;
+  const float* wptr_base = weights;
+  float* outptr_base = outptr;
+
+  const int weight_col_stride = n_channels;  // Adjacent kernel columns are n_channels floats apart
+  const int weight_row_stride = kernel_cols * n_channels;  // Adjacent kernel rows
+
+  // Perform the depthwise convolution: 4 channels at a time on AArch64, then one by one
+  int channels_remaining = n_channels;
+#ifdef __aarch64__
+  for (; channels_remaining >= 4; channels_remaining -= 4)
+  {
+    // Load input tile, substituting zeros for padded cells (these are never read from memory)
+    float32x4_t u[inner_tile_rows][inner_tile_cols];
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      const float* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        if (i < in_pad_top || in_cells_i <= i ||
+            j < in_pad_left || in_cells_j <= j)
+        {
+          u[i][j] = vdupq_n_f32(0.0f);
+        }
+        else
+        {
+          u[i][j] = vld1q_f32(inptr_row + (j - in_pad_left)*in_col_stride);
+        }
+      }
+    }
+    inptr_base += 4;  // Move on to the next 4 channels
+
+    // Load weights tile (one vector of 4 channels per kernel element)
+    float32x4_t w[kernel_rows][kernel_cols];
+    for (int i = 0; i < kernel_rows; i++)
+    {
+      const float* const wptr_row = wptr_base + i*weight_row_stride;
+      for (int j = 0; j < kernel_cols; j++)
+      {
+        w[i][j] = vld1q_f32(wptr_row + j*weight_col_stride);
+      }
+    }
+    wptr_base += 4;  // Move on to the next 4 channels
+
+    // Perform the convolution for the non-padded output cells only
+    float32x4_t v[out_cells_i][out_cells_j];
+    for (int out_i = 0; out_i < out_cells_i; out_i++)
+    {
+      for (int out_j = 0; out_j < out_cells_j; out_j++)
+      {
+        // Top-left input cell covered by the kernel for this output cell
+        const int base_i = out_i * stride_rows;
+        const int base_j = out_j * stride_cols;
+
+        // Accumulate the products over the kernel window
+        for (int in_i = 0; in_i < kernel_rows; in_i++)
+        {
+          const int i = base_i + in_i;
+          for (int in_j = 0; in_j < kernel_cols; in_j++)
+          {
+            const int j = base_j + in_j;
+            if (in_i == 0 && in_j == 0)  // First tap initialises the accumulator
+            {
+              // v[out_i][out_j] = w[in_i][in_j] * u[i][j];
+              v[out_i][out_j] = vmulq_f32(w[in_i][in_j], u[i][j]);
+            }
+            else
+            {
+              // v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+              v[out_i][out_j] = vmlaq_f32(v[out_i][out_j], w[in_i][in_j], u[i][j]);
+            }
+          }
+        }
+      }
+    }
+
+    // Store the output tile (4 channels per store)
+    for (int i = 0; i < out_cells_i; i++)
+    {
+      float* const outptr_row = outptr_base + i*out_row_stride;
+      for (int j = 0; j < out_cells_j; j++)
+      {
+        vst1q_f32(outptr_row + j*out_col_stride, v[i][j]);
+      }
+    }
+    outptr_base += 4;  // Move on to the next 4 channels
+  }
+#endif  // __aarch64__
+  for (; channels_remaining; channels_remaining--)  // Scalar path: leftover channels (all of them when not AArch64)
+  {
+    // Load input tile, substituting zeros for padded cells (these are never read from memory)
+    float u[inner_tile_rows][inner_tile_cols];
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      const float* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        if (i < in_pad_top || in_cells_i <= i ||
+            j < in_pad_left || in_cells_j <= j)
+        {
+          u[i][j] = static_cast<float>(0);
+        }
+        else
+        {
+          u[i][j] = *(inptr_row + (j - in_pad_left)*in_col_stride);
+        }
+      }
+    }
+    inptr_base++;  // Move on to the next channel
+
+    // Load weights tile (one value per kernel element for this channel)
+    float w[kernel_rows][kernel_cols];
+    for (int i = 0; i < kernel_rows; i++)
+    {
+      const float* const wptr_row = wptr_base + i*weight_row_stride;
+      for (int j = 0; j < kernel_cols; j++)
+      {
+        w[i][j] = *(wptr_row + j*weight_col_stride);
+      }
+    }
+    wptr_base++;  // Move on to the next channel
+
+    // Perform the convolution for the non-padded output cells only
+    float v[out_cells_i][out_cells_j];
+    for (int out_i = 0; out_i < out_cells_i; out_i++)
+    {
+      for (int out_j = 0; out_j < out_cells_j; out_j++)
+      {
+        // Clear the accumulator
+        v[out_i][out_j] = static_cast<float>(0);
+
+        // Top-left input cell covered by the kernel for this output cell
+        const int base_i = out_i * stride_rows;
+        const int base_j = out_j * stride_cols;
+
+        // Accumulate the products over the kernel window
+        for (int in_i = 0; in_i < kernel_rows; in_i++)
+        {
+          const int i = base_i + in_i;
+          for (int in_j = 0; in_j < kernel_cols; in_j++)
+          {
+            const int j = base_j + in_j;
+            v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+          }
+        }
+      }
+    }
+
+    // Store the output tile (one channel per store)
+    for (int i = 0; i < out_cells_i; i++)
+    {
+      float* const outptr_row = outptr_base + i*out_row_stride;
+      for (int j = 0; j < out_cells_j; j++)
+      {
+        *(outptr_row + j*out_col_stride) = v[i][j];
+      }
+    }
+    outptr_base++;  // Move on to the next channel
+  }
+}
+
+}  // namespace depthwise
diff --git a/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
similarity index 98%
rename from arm_compute/core/NEON/kernels/winograd/gemm.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
index e48d31b..62a20c9 100644
--- a/arm_compute/core/NEON/kernels/winograd/gemm.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
@@ -23,7 +23,7 @@
  */
 
 #pragma once
-#include "utils.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
 
 template <typename TIn, typename TOut>
 inline void Gemm(const TIn* const a, const TIn* const b, TOut *c,
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
similarity index 99%
rename from arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
index caeb48f..8073cb1 100644
--- a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
@@ -24,7 +24,7 @@
 
 #pragma once
 #include <cassert>
-#include "../utils.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
 
 #ifdef __aarch64__
 
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
similarity index 98%
rename from arm_compute/core/NEON/kernels/winograd/transforms/input.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
index 075765a..6dd8f54 100644
--- a/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
@@ -23,7 +23,7 @@
  */
 
 #pragma once
-#include "../winograd_gemm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
 
 namespace winograd
 {
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
similarity index 96%
rename from arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
index 4b54dfd..bad3ef2 100644
--- a/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
@@ -22,7 +22,7 @@
  * SOFTWARE.
  */
 
-#include "winograd_gemm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
 using namespace winograd;
 
 
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
similarity index 98%
rename from arm_compute/core/NEON/kernels/winograd/transforms/output.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
index 0dd7197..401b281 100644
--- a/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
@@ -23,7 +23,7 @@
  */
 
 #pragma once
-#include "../winograd_gemm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
 
 namespace winograd
 {
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
similarity index 97%
rename from arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
rename to arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
index 2ea70f1..f3b2bb1 100644
--- a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
@@ -24,13 +24,13 @@
 
 #pragma once
 
-#include "alloc.hpp"
-#include "convolution.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/alloc.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/convolution.hpp"
 #include "gemm.hpp"
-#include "profiler.hpp"
-#include "shims.hpp"
-#include "tensor.hpp"
-#include "utils.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/profiler.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/shims.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
 
 #include <thread>
 #include <utility>
diff --git a/arm_compute/core/NEON/kernels/convolution/NEDirectConvolution3x3.h b/arm_compute/core/NEON/kernels/detail/NEDirectConvolution3x3.h
similarity index 99%
rename from arm_compute/core/NEON/kernels/convolution/NEDirectConvolution3x3.h
rename to arm_compute/core/NEON/kernels/detail/NEDirectConvolution3x3.h
index 7f39e5e..fee2066 100644
--- a/arm_compute/core/NEON/kernels/convolution/NEDirectConvolution3x3.h
+++ b/arm_compute/core/NEON/kernels/detail/NEDirectConvolution3x3.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
diff --git a/arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
similarity index 100%
rename from arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h
rename to arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
diff --git a/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp
deleted file mode 100644
index 6a9984a..0000000
--- a/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-#include "convolution.hpp"
-#include "tensor.hpp"
-
-void direct_convolution(
-  const Tensor4D<Tensor4DShape, float>& input,
-  const Tensor4D<KernelShape, float>& kernel,
-  const Tensor4D<Tensor4DShape, float>& biases,
-  Tensor4D<Tensor4DShape, float>& output,
-  const PaddingType padding
-);
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 72be5cb..5a08ac9 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -106,6 +106,13 @@
 /* Constant value used to indicate a ORB scaled pyramid */
 constexpr float SCALE_PYRAMID_ORB = 8.408964152537146130583778358414e-01;
 
+/** Supported tensor data layouts */
+enum class DataLayout
+{
+    NCHW, /* NOTE(review): presumably batch, channels, height, width order - confirm */
+    NHWC  /* NOTE(review): presumably batch, height, width, channels order - confirm */
+};
+
 /** Quantization settings (used for QASYMM8 data type) */
 struct QuantizationInfo
 {
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index fc89d97..111eac0 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -602,6 +602,16 @@
     }
 }
 
+/** Calculate padding requirements in case of SAME padding
+ *
+ * @param[in] input_shape   Input shape
+ * @param[in] weights_shape Weights shape
+ * @param[in] conv_info     Convolution information (containing strides)
+ *
+ * @return PadStrideInfo for SAME padding
+ */
+PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info);
+
 /** Returns expected shape for the deconvolution output tensor.
  *
  * @param[in] out_dims widht and height of the output tensor, these values can be obtained with the function deconvolution_output_dimensions.