DirectConv3d support refine

- Decouple data support of CpuDirectConv3dKernel
- Update documentation for Conv3d

Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I1d94aa28f821f45a1a3d39cc3335c8faeee89f0d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6453
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/runtime/NEON/functions/NEConv3D.h b/arm_compute/runtime/NEON/functions/NEConv3D.h
index 487d357..2b3a45f 100644
--- a/arm_compute/runtime/NEON/functions/NEConv3D.h
+++ b/arm_compute/runtime/NEON/functions/NEConv3D.h
@@ -29,7 +29,6 @@
 #include "arm_compute/core/ITensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/FunctionDescriptors.h"
-#include "arm_compute/runtime/MemoryGroup.h"
 
 #include <memory>
 
diff --git a/docs/user_guide/data_layout.dox b/docs/user_guide/data_layout.dox
index 97d3ea6..ae69bbf 100644
--- a/docs/user_guide/data_layout.dox
+++ b/docs/user_guide/data_layout.dox
@@ -34,8 +34,9 @@
 
 - NHWC: The native layout of Compute Library that delivers the best performance where channels are in the fastest changing dimension
 - NCHW: Legacy layout where width is in the fastest changing dimension
+- NDHWC: New data layout for supporting 3D operators
 
-, where N = batch, C = channel, H = height, W = width.
+, where N = batch, C = channel, H = height, W = width, D = depth.
 
 */
 } // namespace
diff --git a/docs/user_guide/operator_list.dox b/docs/user_guide/operator_list.dox
index ebc970d..1d06a39 100644
--- a/docs/user_guide/operator_list.dox
+++ b/docs/user_guide/operator_list.dox
@@ -52,9 +52,10 @@
   <ul>
     <li>NHWC: The native layout of Compute Library that delivers the best performance where channels are in the fastest changing dimension
     <li>NCHW: Legacy layout where width is in the fastest changing dimension
+    <li>NDHWC: New data layout for supporting 3D operators
     <li>All: Agnostic to any specific data layout
   </ul>
-where N = batches, C = channels, H = height, W = width
+where N = batches, C = channels, H = height, W = width, D = depth
 
 <table>
 <caption id="multi_row"></caption>
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index 583cf4f..2470b45 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -40,6 +40,14 @@
 
 @section S2_2_changelog Changelog
 
+v21.11 Public major release
+ - Various bug fixes.
+ - Various optimizations.
+ - New OpenCL kernels / functions:
+   - @ref CLConv3D
+ - New Arm® Neon™ kernels / functions:
+   - @ref NEConv3D
+
 v21.08 Public major release
  - Various bug fixes.
  - Various optimizations:
diff --git a/src/core/NEON/NEMath.h b/src/core/NEON/NEMath.h
index 13484c9..8118c47 100644
--- a/src/core/NEON/NEMath.h
+++ b/src/core/NEON/NEMath.h
@@ -239,6 +239,14 @@
  */
 float32x2_t vsin_f32(float32x2_t val);
 
+/** Reduce a vector to be a scalar by accumulating all lanes in the vector
+ *
+ * @param[in] v Vector to be reduced.
+ *
+ * @return The sum of all lanes of the vector.
+ */
+float vreduce(const float32x4_t &v);
+
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /** Calculate hyperbolic tangent.
  *
@@ -319,6 +327,13 @@
  */
 float16x8_t vsinq_f16(float16x8_t val);
 
+/** Reduce a vector to be a scalar by accumulating all lanes in the vector
+ *
+ * @param[in] v Vector to be reduced.
+ *
+ * @return The sum of all lanes of the vector.
+ */
+float16_t vreduce(const float16x8_t &v);
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 } // namespace arm_compute
 #include "src/core/NEON/NEMath.inl"
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 5ac62ba..05cf301 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -193,7 +193,7 @@
     static const float32x4_t CONST_THR      = vdupq_n_f32(5.e-3);
     static const float32x4_t CONST_1_3      = vdupq_n_f32(0.3333333f);
 
-    float32x4_t x     = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH);
+    float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH);
     // x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise
     float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x));
     float32x4_t num   = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x));
@@ -418,6 +418,18 @@
     return convert_int8x16_to_float32x4x4(in);
 }
 
+inline float vreduce(const float32x4_t &v)
+{
+    const float32x2_t v0    = vget_high_f32(v);
+    const float32x2_t v1    = vget_low_f32(v);
+    const float32x2_t v_out = vadd_f32(v0, v1);
+
+    const float a = vget_lane_f32(v_out, 0);
+    const float b = vget_lane_f32(v_out, 1);
+
+    return a + b;
+}
+
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /** Exponent polynomial coefficients */
 /** Logarithm polynomial coefficients */
@@ -550,6 +562,19 @@
     return vcvt_f16_f32(vcombine_f32(res_low, res_high));
 }
 
+inline float16_t vreduce(const float16x8_t &v)
+{
+    const float16x4_t v0    = vget_high_f16(v);
+    const float16x4_t v1    = vget_low_f16(v);
+    const float16x4_t v_out = vadd_f16(v0, v1);
+
+    const float16_t a = vget_lane_f16(v_out, 0);
+    const float16_t b = vget_lane_f16(v_out, 1);
+    const float16_t c = vget_lane_f16(v_out, 2);
+    const float16_t d = vget_lane_f16(v_out, 3);
+
+    return a + b + c + d;
+}
 #endif /* DOXYGEN_SKIP_THIS */
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 } // namespace arm_compute
diff --git a/src/cpu/kernels/CpuDirectConv2dKernel.cpp b/src/cpu/kernels/CpuDirectConv2dKernel.cpp
index db1b5f3..68de980 100644
--- a/src/cpu/kernels/CpuDirectConv2dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv2dKernel.cpp
@@ -711,17 +711,6 @@
     }
 };
 
-float vreduce(const float32x4_t &v)
-{
-    auto v0    = wrapper::vgethigh(v);
-    auto v1    = wrapper::vgetlow(v);
-    auto v_out = wrapper::vadd(v0, v1);
-
-    float a = wrapper::vgetlane(v_out, 0);
-    float b = wrapper::vgetlane(v_out, 1);
-    return a + b;
-}
-
 template <typename T1, typename T2>
 inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                          const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
index fecdb2b..595b5f1 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
@@ -23,9 +23,6 @@
  */
 #include "src/cpu/kernels/CpuDirectConv3dKernel.h"
 
-#include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
-#include "src/core/NEON/wrapper/wrapper.h"
-
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/IAccessWindow.h"
@@ -35,8 +32,10 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "src/core/CPP/Validate.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/common/Registrars.h"
 #include "src/core/helpers/AutoConfiguration.h"
-#include "src/core/helpers/WindowHelpers.h"
+#include "src/cpu/kernels/conv3d/neon/list.h"
 
 #include <algorithm>
 
@@ -50,35 +49,86 @@
 {
 namespace
 {
-Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info)
+struct DirectConv3dSelectorData
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
-    ARM_COMPUTE_RETURN_ERROR_ON(src->data_layout() != DataLayout::NDHWC);
-    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);
+    DataType       dt;
+    const CPUInfo &ci;
+};
+using DirectConv3dSelectorPtr = std::add_pointer<bool(const DirectConv3dSelectorData &data)>::type;
+using DirectConv3dKernelPtr   = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Conv3dInfo &, const Window &)>::type;
+struct DirectConv3dKernel
+{
+    const char                   *name;
+    const DirectConv3dSelectorPtr is_selected;
+    DirectConv3dKernelPtr         ukernel;
+};
 
-    const DataLayout data_layout = src->data_layout();
+static const DirectConv3dKernel available_kernels[] =
+{
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    {
+        "neon_fp16_directconv3d",
+        [](const DirectConv3dSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_fp16(); },
+        REGISTER_FP16_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float16_t>)
+    },
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
+    {
+        "neon_fp32_directconv3d",
+        [](const DirectConv3dSelectorData & data) { return data.dt == DataType::F32; },
+        REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>)
+    }
+};
+
+/** Micro-kernel selector
+ *
+ * @param[in] data Selection data passed to help pick the appropriate micro-kernel
+ *
+ * @return A matching micro-kernel else nullptr
+ */
+const DirectConv3dKernel *get_implementation(const DirectConv3dSelectorData &data)
+{
+    for(const auto &uk : available_kernels)
+    {
+        if(uk.is_selected(data))
+        {
+            return &uk;
+        }
+    }
+    return nullptr;
+}
+
+Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst); // Must come first: the micro-kernel lookup below dereferences src0
+
+    const auto *uk = get_implementation(DirectConv3dSelectorData{ src0->data_type(), CPUInfo::get() });
+    ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
+    ARM_COMPUTE_RETURN_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src0);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1);
+
+    const DataLayout data_layout = src0->data_layout();
     const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
 
     // Weight layout is D, H, W, Cin, Cout
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 5);
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(1) != src->dimension(channel_idx));
+    ARM_COMPUTE_RETURN_ERROR_ON(src1->num_dimensions() > 5);
+    ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(1) != src0->dimension(channel_idx));
 
-    if(biases != nullptr)
+    if(src2 != nullptr)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(0) != weights->dimension(0),
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0),
                                         "biases size and number of output feature maps should match");
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "biases should be one dimensional");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "biases should be one dimensional");
     }
 
     // Checks performed when output is configured
     if(dst->total_size() != 0)
     {
-        TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv_info);
+        TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
 
-        DataType data_type = src->data_type();
+        DataType data_type = src0->data_type();
 
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape);
         ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type);
@@ -86,200 +136,39 @@
 
     return Status{};
 }
-
-/** Reduce a vector to be a scalar by accumulating all lanes in the vector
- *
- * @param[in] v Vector to be reduced.
- *
- * @return the wrapped-around number.
- */
-auto vreduce(const float32x4_t &v)
-{
-    auto v0    = wrapper::vgethigh(v);
-    auto v1    = wrapper::vgetlow(v);
-    auto v_out = wrapper::vadd(v0, v1);
-
-    float a = wrapper::vgetlane(v_out, 0);
-    float b = wrapper::vgetlane(v_out, 1);
-    return a + b;
 }
 
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-auto vreduce(const float16x8_t &v)
+void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info)
 {
-    auto v0    = wrapper::vgethigh(v);
-    auto v1    = wrapper::vgetlow(v);
-    auto v_out = wrapper::vadd(v0, v1);
+    ARM_COMPUTE_UNUSED(src2);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
 
-    float16_t a = wrapper::vgetlane(v_out, 0);
-    float16_t b = wrapper::vgetlane(v_out, 1);
-    float16_t c = wrapper::vgetlane(v_out, 2);
-    float16_t d = wrapper::vgetlane(v_out, 3);
-    return a + b + c + d;
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-}
+    const auto *uk = get_implementation(DirectConv3dSelectorData{ src0->data_type(), CPUInfo::get() });
+    ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
 
-template <typename T>
-void CpuDirectConv3dKernel::convolve_ndhwc(const Window &window, const ITensor *src, const ITensor *weights, const ITensor *biases, ITensor *dst)
-{
-    using vtype                                = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
-    using vector_type                          = typename vtype::type;
-    using tag_type                             = typename vtype::tag_type;
-    constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
-
-    // Scalar quantities (N D H W Cin)
-    const int element_size   = src->info()->element_size();
-    const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
-    const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
-    const int input_stride_d = src->info()->strides_in_bytes()[3] / element_size;
-    const int input_stride_n = src->info()->strides_in_bytes()[4] / element_size;
-    const int input_dim_w    = src->info()->dimension(1);
-    const int input_dim_h    = src->info()->dimension(2);
-    const int input_dim_d    = src->info()->dimension(3);
-
-    // Kernel info (D H W Cin Cout)
-    const unsigned int kernel_stride_w = weights->info()->strides_in_bytes()[2] / element_size;
-    const unsigned int kernel_stride_h = weights->info()->strides_in_bytes()[3] / element_size;
-    const unsigned int kernel_stride_d = weights->info()->strides_in_bytes()[4] / element_size;
-    const int          kernel_dim_w    = weights->info()->dimension(2);
-    const int          kernel_dim_h    = weights->info()->dimension(3);
-    const int          kernel_dim_d    = weights->info()->dimension(4);
-
-    // Convolution padding and stride
-    const int conv_pad_top   = _conv_info.padding.top;
-    const int conv_pad_left  = _conv_info.padding.left;
-    const int conv_pad_front = _conv_info.padding.front;
-    const int conv_stride_w  = _conv_info.stride.width;
-    const int conv_stride_h  = _conv_info.stride.height;
-    const int conv_stride_d  = _conv_info.stride.depth;
-
-    // Setup input window for the output iterator
-    Window window_out = window;
-    window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-    // Setup input window for the weights iterator
-    Window window_w = calculate_max_window(*weights->info(), Steps());
-    window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
-    window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
-    window_w.set(Window::DimW, Window::Dimension(0, 1, 1));
-    window_w.set(4, Window::Dimension(0, 1, 1));
-
-    Iterator out(dst, window_out);
-    Iterator wei(weights, window_w);
-
-    const T *biases_ptr = nullptr;
-    if(biases)
-    {
-        biases_ptr = reinterpret_cast<T *>(biases->buffer() + biases->info()->offset_first_element_in_bytes());
-    }
-    execute_window_loop(window_out, [&](const Coordinates & id)
-    {
-        // We are computing the theoretical input starting points
-        const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
-        const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
-        const int in_d_start_t = static_cast<int>(id[3]) * conv_stride_d - conv_pad_front;
-        const int in_w_end_t   = in_w_start_t + kernel_dim_w;
-        const int in_h_end_t   = in_h_start_t + kernel_dim_h;
-        const int in_d_end_t   = in_d_start_t + kernel_dim_d;
-
-        // We are computing the valid initial and ending input points by checking the borders
-        const int in_w_start = std::max(in_w_start_t, 0);
-        const int in_h_start = std::max(in_h_start_t, 0);
-        const int in_d_start = std::max(in_d_start_t, 0);
-        const int in_w_end   = std::min(in_w_end_t, input_dim_w);
-        const int in_h_end   = std::min(in_h_end_t, input_dim_h);
-        const int in_d_end   = std::min(in_d_end_t, input_dim_d);
-
-        // We use the input points to select the valid weight points to use
-        const int wei_w_start = in_w_start - in_w_start_t;
-        const int wei_h_start = in_h_start - in_h_start_t;
-        const int wei_d_start = in_d_start - in_d_start_t;
-        const int wei_w_end   = kernel_dim_w - (in_w_end_t - in_w_end);
-        const int wei_h_end   = kernel_dim_h - (in_h_end_t - in_h_end);
-        const int wei_d_end   = kernel_dim_d - (in_d_end_t - in_d_end);
-
-        const int      index_c_out_end = weights->info()->dimension(0);
-        const int      index_c_in_end  = weights->info()->dimension(1);
-        const T *const in_ptr_start    = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[4] * input_stride_n;
-
-        execute_window_loop(window_w, [&](const Coordinates & id_w)
-        {
-            /*
-            * This is the loop in the weights, and it goes along OFM (output feature map)
-            */
-            const auto weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
-            T          out_temp          = static_cast<T>(0);
-            T         *out_ptr           = reinterpret_cast<T *>(out.ptr());
-            for(int index_wei_d = wei_d_start, index_in_d = in_d_start; index_wei_d < wei_d_end; ++index_wei_d, ++index_in_d)
-            {
-                const auto in_ptr_d      = in_ptr_start + index_in_d * input_stride_d;
-                const auto weights_ptr_d = weights_ptr_start + index_wei_d * kernel_stride_d;
-                for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
-                {
-                    const T *const in_ptr_row      = in_ptr_d + index_in_h * input_stride_h;
-                    const T *const weights_ptr_row = weights_ptr_d + index_wei_h * kernel_stride_h;
-                    for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
-                    {
-                        const T    *in_ptr_mover      = in_ptr_row + index_in_w * input_stride_w;
-                        const T    *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
-                        int         index_c_in        = 0;
-                        vector_type out_temp_vec      = wrapper::vdup_n(static_cast<T>(0), tag_type());
-                        vector_type w_vec             = wrapper::vdup_n(static_cast<T>(0), tag_type());
-                        for(; index_c_in <= index_c_in_end - num_elems_read_per_iteration;
-                            index_c_in += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
-                        {
-                            const auto src_vec = wrapper::vloadq(in_ptr_mover);
-                            //Load Cin weights
-                            for(unsigned int k = 0; k < num_elems_read_per_iteration; ++k, weights_ptr_mover += index_c_out_end)
-                            {
-                                w_vec = wrapper::vsetlane(*weights_ptr_mover, w_vec, k);
-                            }
-                            out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec);
-                        }
-                        out_temp += vreduce(out_temp_vec);
-                        for(; index_c_in < index_c_in_end; ++index_c_in, ++in_ptr_mover, weights_ptr_mover += index_c_out_end)
-                        {
-                            const auto src_val = *(in_ptr_mover);
-                            const auto w_val   = *(weights_ptr_mover);
-                            out_temp += src_val * w_val;
-                        }
-                    }
-                }
-            }
-            *(reinterpret_cast<T *>(out_ptr + id_w[0])) = (biases) ? out_temp + biases_ptr[id_w[0]] : out_temp;
-        },
-        wei);
-    },
-    out);
-}
-
-void CpuDirectConv3dKernel::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv_info)
-{
-    ARM_COMPUTE_UNUSED(biases);
-    ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
-
-    _conv_info = conv_info;
+    _conv_info  = conv_info;
+    _run_method = uk->ukernel;
+    _name       = std::string("CpuDirectConv3dKernel").append("/").append(uk->name);
 
     // Get convolved dimensions
-    TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv_info);
+    TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
 
-    DataType data_type = src->data_type();
+    DataType data_type = src0->data_type();
 
     // Output auto inizialitation if not yet initialized
     auto_init_if_empty(*dst, output_shape, 1, data_type);
 
     // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, conv_info));
 
     // Configure kernel window
     Window win = calculate_max_window(*dst, Steps());
     ICpuKernel::configure(win);
 }
 
-Status CpuDirectConv3dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info)
+Status CpuDirectConv3dKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, conv_info));
 
     return Status{};
 }
@@ -289,35 +178,19 @@
     ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
+    ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
 
-    auto src     = tensors.get_const_tensor(TensorType::ACL_SRC_0);
-    auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
-    auto biases  = tensors.get_const_tensor(TensorType::ACL_SRC_2);
-    auto dst     = tensors.get_tensor(TensorType::ACL_DST);
+    auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+    auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+    auto src2 = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+    auto dst  = tensors.get_tensor(TensorType::ACL_DST);
 
-    switch(src->info()->data_type())
-    {
-        case DataType::F32:
-        {
-            convolve_ndhwc<float>(window, src, weights, biases, dst);
-            break;
-        }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        case DataType::F16:
-        {
-            convolve_ndhwc<float16_t>(window, src, weights, biases, dst);
-            break;
-        }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        default:
-            ARM_COMPUTE_ERROR("Data type not supported");
-            break;
-    }
+    _run_method(src0, src1, src2, dst, _conv_info, window);
 }
 
 const char *CpuDirectConv3dKernel::name() const
 {
-    return "CpuDirectConv3dKernel";
+    return _name.c_str();
 }
 } // namespace kernels
 } // namespace cpu
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.h b/src/cpu/kernels/CpuDirectConv3dKernel.h
index c7dcb0f..fc64e85 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.h
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.h
@@ -39,10 +39,7 @@
 public:
     CpuDirectConv3dKernel() = default;
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConv3dKernel);
-    /** Set the src, weights, and dst tensor info.
-     *
-     * Valid data layouts:
-     * - NDHWC
+    /** Set the src, weights, biases and dst tensor info.
      *
      * Valid data type configurations:
      * |src0           |src1               |src2   |dst            |
@@ -50,34 +47,35 @@
      * |F16            |F16                |F16    |F16            |
      * |F32            |F32                |F32    |F32            |
      *
-     * @param[in, out] src       Input tensor info.
-     * @param[in]      weights   Set of kernels to convolve the input volume.
+     * @param[in]      src0      Input tensor info.
+     * @param[in]      src1      Set of kernels to convolve the input volume.
      *                           The 2nd dimension must be the same as the input's volume 1st dimension.
-     * @param[in]      biases    Set of biases. Can be nullptr.
+     * @param[in]      src2      Set of biases. Can be nullptr.
      * @param[out]     dst       Output tensor info.
      *                           The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor.
      * @param[in]      conv_info Contains padding, stride, acitvation information.
      *
      */
-    void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv_info);
+    void configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info);
     /** Static function to check if given info will lead to a valid configuration
      *
      * Similar to CpuDirectConv3dKernel::configure()
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info);
+    static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info);
 
     // Inherited methods overridden:
     void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
     const char *name() const override;
 
 private:
-    /* Template function for convolution NDHWC */
-    template <typename T>
-    void convolve_ndhwc(const Window &window, const ITensor *src, const ITensor *weights, const ITensor *biases, ITensor *dst);
+    /* Function pointer type of the micro-kernel that runs the 3D convolution (NDHWC) */
+    using DirectConv3dKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Conv3dInfo &, const Window &)>::type;
 
-    Conv3dInfo _conv_info{};
+    Conv3dInfo            _conv_info{};
+    DirectConv3dKernelPtr _run_method{ nullptr };
+    std::string           _name{};
 };
 } // namespace kernels
 } // namespace cpu
diff --git a/src/cpu/kernels/conv3d/neon/list.h b/src/cpu/kernels/conv3d/neon/list.h
new file mode 100644
index 0000000..b24785a
--- /dev/null
+++ b/src/cpu/kernels/conv3d/neon/list.h
@@ -0,0 +1,176 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+#define SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/Traits.h"
+#include "arm_compute/runtime/FunctionDescriptors.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/helpers/WindowHelpers.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+template <typename T>
+void directconv3d_float_neon_ndhwc(const ITensor *src0, const ITensor *src1, const ITensor *src2, ITensor *dst, const Conv3dInfo &conv_info, const Window &window)
+{
+    const ITensor *src     = src0;
+    const ITensor *weights = src1;
+    const ITensor *biases  = src2;
+
+    using vtype                                = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
+    using vector_type                          = typename vtype::type;
+    using tag_type                             = typename vtype::tag_type;
+    constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
+
+    // Scalar quantities (N D H W Cin)
+    const int element_size   = src->info()->element_size();
+    const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
+    const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
+    const int input_stride_d = src->info()->strides_in_bytes()[3] / element_size;
+    const int input_stride_n = src->info()->strides_in_bytes()[4] / element_size;
+    const int input_dim_w    = src->info()->dimension(1);
+    const int input_dim_h    = src->info()->dimension(2);
+    const int input_dim_d    = src->info()->dimension(3);
+
+    // Kernel info (D H W Cin Cout)
+    const unsigned int kernel_stride_w = weights->info()->strides_in_bytes()[2] / element_size;
+    const unsigned int kernel_stride_h = weights->info()->strides_in_bytes()[3] / element_size;
+    const unsigned int kernel_stride_d = weights->info()->strides_in_bytes()[4] / element_size;
+    const int          kernel_dim_w    = weights->info()->dimension(2);
+    const int          kernel_dim_h    = weights->info()->dimension(3);
+    const int          kernel_dim_d    = weights->info()->dimension(4);
+
+    // Convolution padding and stride
+    const int conv_pad_top   = conv_info.padding.top;
+    const int conv_pad_left  = conv_info.padding.left;
+    const int conv_pad_front = conv_info.padding.front;
+    const int conv_stride_w  = conv_info.stride.width;
+    const int conv_stride_h  = conv_info.stride.height;
+    const int conv_stride_d  = conv_info.stride.depth;
+
+    // Set up the window for the output iterator
+    Window window_out = window;
+    window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Set up the window for the weights iterator (only dimension 0 is iterated)
+    Window window_w = calculate_max_window(*weights->info(), Steps());
+    window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
+    window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
+    window_w.set(Window::DimW, Window::Dimension(0, 1, 1));
+    window_w.set(4, Window::Dimension(0, 1, 1));
+
+    Iterator out(dst, window_out);
+    Iterator wei(weights, window_w);
+
+    const T *biases_ptr = nullptr;
+    if(biases != nullptr)
+    {
+        biases_ptr = reinterpret_cast<T *>(biases->buffer() + biases->info()->offset_first_element_in_bytes());
+    }
+    execute_window_loop(window_out, [&](const Coordinates & id)
+    {
+        // We are computing the theoretical input starting points
+        const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
+        const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
+        const int in_d_start_t = static_cast<int>(id[3]) * conv_stride_d - conv_pad_front;
+        const int in_w_end_t   = in_w_start_t + kernel_dim_w;
+        const int in_h_end_t   = in_h_start_t + kernel_dim_h;
+        const int in_d_end_t   = in_d_start_t + kernel_dim_d;
+
+        // We are computing the valid initial and ending input points by checking the borders
+        const int in_w_start = std::max(in_w_start_t, 0);
+        const int in_h_start = std::max(in_h_start_t, 0);
+        const int in_d_start = std::max(in_d_start_t, 0);
+        const int in_w_end   = std::min(in_w_end_t, input_dim_w);
+        const int in_h_end   = std::min(in_h_end_t, input_dim_h);
+        const int in_d_end   = std::min(in_d_end_t, input_dim_d);
+
+        // We use the input points to select the valid weight points to use
+        const int wei_w_start = in_w_start - in_w_start_t;
+        const int wei_h_start = in_h_start - in_h_start_t;
+        const int wei_d_start = in_d_start - in_d_start_t;
+        const int wei_w_end   = kernel_dim_w - (in_w_end_t - in_w_end);
+        const int wei_h_end   = kernel_dim_h - (in_h_end_t - in_h_end);
+        const int wei_d_end   = kernel_dim_d - (in_d_end_t - in_d_end);
+
+        const int      index_c_out_end = weights->info()->dimension(0);
+        const int      index_c_in_end  = weights->info()->dimension(1);
+        const T *const in_ptr_start    = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[4] * input_stride_n;
+
+        execute_window_loop(window_w, [&](const Coordinates & id_w)
+        {
+            /*
+            * This is the loop in the weights, and it goes along OFM (output feature map)
+            */
+            const auto weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
+            T          out_temp          = static_cast<T>(0);
+            T         *out_ptr           = reinterpret_cast<T *>(out.ptr());
+            for(int index_wei_d = wei_d_start, index_in_d = in_d_start; index_wei_d < wei_d_end; ++index_wei_d, ++index_in_d)
+            {
+                const auto in_ptr_d      = in_ptr_start + index_in_d * input_stride_d;
+                const auto weights_ptr_d = weights_ptr_start + index_wei_d * kernel_stride_d;
+                for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
+                {
+                    const T *const in_ptr_row      = in_ptr_d + index_in_h * input_stride_h;
+                    const T *const weights_ptr_row = weights_ptr_d + index_wei_h * kernel_stride_h;
+                    for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
+                    {
+                        const T    *in_ptr_mover      = in_ptr_row + index_in_w * input_stride_w;
+                        const T    *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
+                        int         index_c_in        = 0;
+                        vector_type out_temp_vec      = wrapper::vdup_n(static_cast<T>(0), tag_type());
+                        vector_type w_vec             = wrapper::vdup_n(static_cast<T>(0), tag_type());
+                        for(; index_c_in <= index_c_in_end - num_elems_read_per_iteration;
+                            index_c_in += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
+                        {
+                            const auto src_vec = wrapper::vloadq(in_ptr_mover);
+                            // Gather one weight per input channel (consecutive Cin weights are index_c_out_end elements apart)
+                            for(unsigned int k = 0; k < num_elems_read_per_iteration; ++k, weights_ptr_mover += index_c_out_end)
+                            {
+                                w_vec = wrapper::vsetlane(*weights_ptr_mover, w_vec, k);
+                            }
+                            out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec);
+                        }
+                        out_temp += vreduce(out_temp_vec);
+                        for(; index_c_in < index_c_in_end; ++index_c_in, ++in_ptr_mover, weights_ptr_mover += index_c_out_end)
+                        {
+                            const auto src_val = *(in_ptr_mover);
+                            const auto w_val   = *(weights_ptr_mover);
+                            out_temp += src_val * w_val;
+                        }
+                    }
+                }
+            }
+            *(reinterpret_cast<T *>(out_ptr + id_w[0])) = (biases_ptr != nullptr) ? out_temp + biases_ptr[id_w[0]] : out_temp;
+        },
+        wei);
+    },
+    out);
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif // SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
\ No newline at end of file
diff --git a/src/cpu/operators/CpuDirectConv3d.cpp b/src/cpu/operators/CpuDirectConv3d.cpp
index 3827910..aa74e42 100644
--- a/src/cpu/operators/CpuDirectConv3d.cpp
+++ b/src/cpu/operators/CpuDirectConv3d.cpp
@@ -40,10 +40,10 @@
 {
 }
 
-void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info)
+void CpuDirectConv3d::configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info)
 {
-    ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info);
-    ARM_COMPUTE_ERROR_ON(src->data_layout() != DataLayout::NDHWC);
+    ARM_COMPUTE_LOG_PARAMS(src0, src1, src2, dst, conv_info);
+    ARM_COMPUTE_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
 
     _conv_kernel = std::make_unique<kernels::CpuDirectConv3dKernel>();
 
@@ -55,7 +55,7 @@
 
     _dim_split = Window::DimY;
 
-    _conv_kernel->configure(src, weights, biases, dst, conv_info);
+    _conv_kernel->configure(src0, src1, src2, dst, conv_info);
 
     //Configure Activation Layer
     _is_activationlayer_enabled = conv_info.act_info.enabled();
@@ -66,16 +66,12 @@
     }
 }
 
-Status CpuDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info)
+Status CpuDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info)
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
-
-    // output might not be initialized since it can be an intermediate tensor of another layer
-    DataType   data_type = src->data_type();
-    TensorInfo accumulator(dst->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
 
     // Validate Convolution kernel
-    ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src, weights, biases, &accumulator, conv_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src0, src1, src2, dst, conv_info));
 
     if(conv_info.act_info.enabled())
     {
diff --git a/src/cpu/operators/CpuDirectConv3d.h b/src/cpu/operators/CpuDirectConv3d.h
index ad04dee..f7c3099 100644
--- a/src/cpu/operators/CpuDirectConv3d.h
+++ b/src/cpu/operators/CpuDirectConv3d.h
@@ -57,23 +57,31 @@
     ~CpuDirectConv3d();
     /** Set the input, weights, biases and output tensor info.
      *
-     * @param[in, out] src       Input tensor info.
-     * @param[in]      weights   Set of kernels to convolve the input volume.
-     *                           The 2nd dimension must be the same as the input's volume 1st dimension.
-     *                           Data type supported: Same as @p src.
-     * @param[in]      biases    Set of biases. Can be nullptr. Data type supported: Same as @p src.
+     * Valid data layouts:
+     * - NDHWC
+     *
+     * Valid data type configurations:
+     * |src0           |src1               |src2   |dst            |
+     * |:--------------|:------------------|:------|:--------------|
+     * |F16            |F16                |F16    |F16            |
+     * |F32            |F32                |F32    |F32            |
+     *
+     * @param[in, out] src0      Input tensor info.
+     * @param[in]      src1      Set of kernels to convolve the input volume.
+     *                           The 2nd dimension must be the same as the 1st dimension of @p src0.
+     * @param[in]      src2      Set of biases. Can be nullptr.
      * @param[out]     dst       Output tensor info.
      *                           The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor.
      * @param[in]      conv_info Contains padding, stride, acitvation information.
      */
-    void configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info);
+    void configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info);
     /** Static function to check if given info will lead to a valid configuration
      *
      * Similar to CpuDirectConv3d::configure()
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info);
+    static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info);
 
     // Inherited methods overridden:
     void run(ITensorPack &tensors) override;
diff --git a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp
index 1c4326b..88e73dc 100644
--- a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp
+++ b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp
@@ -37,36 +37,36 @@
 {
 namespace
 {
-Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
 {
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights, dst);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported");
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(src0, src1, dst);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv3d_info.act_info.enabled(), "Fused activation not supported");
 
-    ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);
+    ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src0);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1);
 
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(1) != src->dimension(0), "Weights feature map dimension should match the respective src's one");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 5, "Weights can be at most 5 dimensional");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->dimension(1) != src0->dimension(0), "Weights feature map dimension should match the respective src's one");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->num_dimensions() > 5, "Weights can be at most 5 dimensional");
 
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) > (src->dimension(1) + conv3d_info.padding.left + conv3d_info.padding.right));
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(3) > (src->dimension(2) + conv3d_info.padding.top + conv3d_info.padding.bottom));
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(4) > (src->dimension(3) + conv3d_info.padding.front + conv3d_info.padding.back));
+    ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(2) > (src0->dimension(1) + conv3d_info.padding.left + conv3d_info.padding.right));
+    ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(3) > (src0->dimension(2) + conv3d_info.padding.top + conv3d_info.padding.bottom));
+    ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(4) > (src0->dimension(3) + conv3d_info.padding.front + conv3d_info.padding.back));
 
-    if(biases != nullptr)
+    if(src2 != nullptr)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(0) != weights->dimension(0), "Biases size and number of dst feature maps should match");
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "Biases should be one dimensional");
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), "Biases size and number of dst feature maps should match");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "Biases should be one dimensional");
     }
 
     // Checks performed when dst is configured
     if(dst->total_size() != 0)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != weights->dimension(0), "Weights and dst OFMs should match");
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv3d_info));
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != src1->dimension(0), "Weights and dst OFMs should match");
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv3d_info));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst);
     }
 
     return Status{};
@@ -78,27 +78,27 @@
     _type = CLKernelType::DIRECT;
 }
 
-void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst,
+void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst,
                                      const Conv3dInfo &conv3d_info)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
 
     // Perform validation
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv3d_info));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, conv3d_info));
 
     // Create window and update padding
-    const DataType data_type      = src->data_type();
-    const size_t   src_width      = src->dimension(1);
-    const size_t   src_height     = src->dimension(2);
-    const size_t   src_depth      = src->dimension(3);
-    const size_t   src_channels   = src->dimension(0);
+    const DataType data_type      = src0->data_type();
+    const size_t   src_width      = src0->dimension(1);
+    const size_t   src_height     = src0->dimension(2);
+    const size_t   src_depth      = src0->dimension(3);
+    const size_t   src_channels   = src0->dimension(0);
     const size_t   dst_width      = dst->dimension(1);
     const size_t   dst_height     = dst->dimension(2);
     const size_t   dst_depth      = dst->dimension(3);
     const size_t   dst_channels   = dst->dimension(0);
-    const size_t   weights_width  = weights->dimension(2);
-    const size_t   weights_height = weights->dimension(3);
-    const size_t   weights_depth  = weights->dimension(4);
+    const size_t   weights_width  = src1->dimension(2);
+    const size_t   weights_height = src1->dimension(3);
+    const size_t   weights_depth  = src1->dimension(4);
     const size_t   pad_left       = conv3d_info.padding.left;
     const size_t   pad_top        = conv3d_info.padding.top;
     const size_t   pad_front      = conv3d_info.padding.front;
@@ -108,7 +108,7 @@
 
     const size_t n0               = std::min(dst->dimension(0), static_cast<size_t>(4u));
     const size_t m0               = (dst->tensor_shape()[0] > 16) ? ((data_type == DataType::F32) ? 2U : 4U) : 1U;
-    const size_t k0               = adjust_vec_size(8u, src->dimension(0));
+    const size_t k0               = adjust_vec_size(8u, src0->dimension(0));
     const size_t partial_store_n0 = dst->dimension(0) % n0;
 
     CLBuildOptions build_options;
@@ -136,7 +136,7 @@
     build_options.add_option("-DM0=" + support::cpp11::to_string(m0));
     build_options.add_option("-DK0=" + support::cpp11::to_string(k0));
     build_options.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
-    build_options.add_option_if(biases != nullptr, std::string("-DHAS_BIAS"));
+    build_options.add_option_if(src2 != nullptr, std::string("-DHAS_BIAS"));
 
     std::string kernel_name = "direct_convolution3d_ndhwc";
     _kernel                 = create_kernel(compile_context, kernel_name, build_options.options());
@@ -169,9 +169,9 @@
     _config_id += support::cpp11::to_string(dst_channels);
 }
 
-Status ClDirectConv3dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+Status ClDirectConv3dKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv3d_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, conv3d_info));
     return Status{};
 }
 
diff --git a/src/gpu/cl/kernels/ClDirectConv3dKernel.h b/src/gpu/cl/kernels/ClDirectConv3dKernel.h
index 9ac8f0d..485c900 100644
--- a/src/gpu/cl/kernels/ClDirectConv3dKernel.h
+++ b/src/gpu/cl/kernels/ClDirectConv3dKernel.h
@@ -61,21 +61,21 @@
      * |F32            |F32            |F32    |F32            |
      *
      * @param[in]  compile_context The compile context to be used.
-     * @param[in]  src             Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
+     * @param[in]  src0            Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
      *                             while every optional dimension from 5 and above represent a batch of srcs.
-     * @param[in]  weights         Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
-     * @param[in]  biases          Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
+     * @param[in]  src1            Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
+     * @param[in]  src2            Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
      * @param[out] dst             Destination tensor. 4 lower dimensions represent a single dst [OFM, width, height, depth], while the rest represent batch of dsts.
      * @param[in]  conv3d_info     Contains strides, padding, rounding, activation, dilation and fast math information. Activation and fast math are currently unused.
      */
-    void configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+    void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
     /** Static function to check if given info will lead to a valid configuration
      *
      * Similar to ClDirectConv3dKernel::configure()
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+    static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
 
     // Inherited methods overridden:
     void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
diff --git a/src/gpu/cl/operators/ClDirectConv3d.cpp b/src/gpu/cl/operators/ClDirectConv3d.cpp
index d101658..5d37f07 100644
--- a/src/gpu/cl/operators/ClDirectConv3d.cpp
+++ b/src/gpu/cl/operators/ClDirectConv3d.cpp
@@ -30,19 +30,19 @@
 {
 namespace opencl
 {
-void ClDirectConv3d::configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+void ClDirectConv3d::configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(src);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(src0);
 
     // Configure direct convolution 3d kernel
     auto k = std::make_unique<kernels::ClDirectConv3dKernel>();
-    k->configure(compile_context, src, weights, biases, dst, conv3d_info);
+    k->configure(compile_context, src0, src1, src2, dst, conv3d_info);
     _direct_conv3d_kernel = std::move(k);
 }
 
-Status ClDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+Status ClDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv3dKernel::validate(src, weights, biases, dst, conv3d_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv3dKernel::validate(src0, src1, src2, dst, conv3d_info));
     return Status{};
 }
 
diff --git a/src/gpu/cl/operators/ClDirectConv3d.h b/src/gpu/cl/operators/ClDirectConv3d.h
index ce9135b..d8ffefc 100644
--- a/src/gpu/cl/operators/ClDirectConv3d.h
+++ b/src/gpu/cl/operators/ClDirectConv3d.h
@@ -57,15 +57,15 @@
      * |F32            |F32            |F32    |F32            |
      *
      * @param[in]  compile_context The compile context to be used.
-     * @param[in]  src             Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
+     * @param[in]  src0            Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
      *                             while every optional dimension from 5 and above represent a batch of srcs.
-     * @param[in]  weights         Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
-     * @param[in]  biases          Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
+     * @param[in]  src1            Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
+     * @param[in]  src2            Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
      * @param[out] dst             Destination tensor. 4 lower dimensions represent a single dst [OFM, width, height, depth], while the rest represent batch of dsts.
      * @param[in]  conv3d_info     Contains strides, padding, rounding, activation, dilation and fast math information. Activation and fast math are currently unused.
      *
      */
-    void configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+    void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
 
     /** Static function to check if given info will lead to a valid configuration
      *
@@ -73,7 +73,7 @@
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+    static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
 
     // Inherited method overridden
     void run(ITensorPack &tensors) override;
diff --git a/src/runtime/NEON/functions/NEConv3D.cpp b/src/runtime/NEON/functions/NEConv3D.cpp
index b5e2e2a..3bb66c4 100644
--- a/src/runtime/NEON/functions/NEConv3D.cpp
+++ b/src/runtime/NEON/functions/NEConv3D.cpp
@@ -27,7 +27,6 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "src/common/utils/Log.h"
-#include "src/core/helpers/MemoryHelpers.h"
 #include "src/cpu/operators/CpuDirectConv3d.h"
 
 namespace arm_compute
@@ -58,7 +57,7 @@
     f->configure(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info);
     _impl->op = std::move(f);
 
-    if(_impl->op)
+    if(_impl->op != nullptr)
     {
         _impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } };
     }
@@ -73,7 +72,7 @@
 
 void NEConv3D::run()
 {
-    if(_impl->op)
+    if(_impl->op != nullptr)
     {
         _impl->op->run(_impl->run_pack);
     }