Refactor Dequantize to enable FP16 kernel in v8a multi_isa builds

Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
COMPMID-7058
Change-Id: I9c6d18a8fddaf335bcd1e8dd562fa3838c1ca4b2
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11561
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
diff --git a/Android.bp b/Android.bp
index fd8afff..1f1e591 100644
--- a/Android.bp
+++ b/Android.bp
@@ -489,6 +489,8 @@
         "src/cpu/kernels/depthwiseconv2d/generic/neon/impl.cpp",
         "src/cpu/kernels/depthwiseconv2d/generic/neon/qasymm8.cpp",
         "src/cpu/kernels/depthwiseconv2d/generic/neon/qasymm8_signed.cpp",
+        "src/cpu/kernels/dequantize/generic/neon/fp16.cpp",
+        "src/cpu/kernels/dequantize/generic/neon/fp32.cpp",
         "src/cpu/kernels/directconv2d/nchw/all.cpp",
         "src/cpu/kernels/directconv2d/nchw/fp16.cpp",
         "src/cpu/kernels/directconv2d/nhwc/neon/fp16.cpp",
@@ -557,7 +559,6 @@
         "src/cpu/kernels/quantize/generic/neon/fp16.cpp",
         "src/cpu/kernels/quantize/generic/neon/fp32.cpp",
         "src/cpu/kernels/quantize/generic/neon/integer.cpp",
-        "src/cpu/kernels/quantize/generic/neon/vquantize.cpp",
         "src/cpu/kernels/range/generic/neon/fp16.cpp",
         "src/cpu/kernels/range/generic/neon/fp32.cpp",
         "src/cpu/kernels/range/generic/neon/integer.cpp",
diff --git a/filelist.json b/filelist.json
index 3ee5304..15449b4 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1415,7 +1415,11 @@
             "src/cpu/operators/CpuDequantize.cpp",
             "src/cpu/kernels/CpuDequantizeKernel.cpp",
             "src/runtime/NEON/functions/NEDequantizationLayer.cpp"
-          ]
+          ],
+          "neon":{
+            "fp32":["src/cpu/kernels/dequantize/generic/neon/fp32.cpp"],
+            "fp16":["src/cpu/kernels/dequantize/generic/neon/fp16.cpp"]
+          }
         }
       },
       "DetectionPostProcess": {
@@ -2093,8 +2097,7 @@
           "common": [
             "src/cpu/operators/CpuQuantize.cpp",
             "src/cpu/kernels/CpuQuantizeKernel.cpp",
-            "src/runtime/NEON/functions/NEQuantizationLayer.cpp",
-            "src/cpu/kernels/quantize/generic/neon/vquantize.cpp"
+            "src/runtime/NEON/functions/NEQuantizationLayer.cpp"
           ],
           "neon":{
             "fp32":["src/cpu/kernels/quantize/generic/neon/fp32.cpp"],
diff --git a/src/BUILD.bazel b/src/BUILD.bazel
index 499e564..f270824 100644
--- a/src/BUILD.bazel
+++ b/src/BUILD.bazel
@@ -753,6 +753,8 @@
 	"cpu/kernels/depthwiseconv2d/generic/neon/impl.cpp",
 	"cpu/kernels/depthwiseconv2d/generic/neon/qasymm8.cpp",
 	"cpu/kernels/depthwiseconv2d/generic/neon/qasymm8_signed.cpp",
+	"cpu/kernels/dequantize/generic/neon/fp16.cpp",
+	"cpu/kernels/dequantize/generic/neon/fp32.cpp",
 	"cpu/kernels/directconv2d/nchw/all.cpp",
 	"cpu/kernels/directconv2d/nchw/fp16.cpp",
 	"cpu/kernels/directconv2d/nhwc/neon/fp16.cpp",
@@ -821,7 +823,6 @@
 	"cpu/kernels/quantize/generic/neon/fp16.cpp",
 	"cpu/kernels/quantize/generic/neon/fp32.cpp",
 	"cpu/kernels/quantize/generic/neon/integer.cpp",
-	"cpu/kernels/quantize/generic/neon/vquantize.cpp",
 	"cpu/kernels/range/generic/neon/fp16.cpp",
 	"cpu/kernels/range/generic/neon/fp32.cpp",
 	"cpu/kernels/range/generic/neon/integer.cpp",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8d63ab5..87c5f8b 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -744,6 +744,8 @@
 	cpu/kernels/depthwiseconv2d/generic/neon/impl.cpp
 	cpu/kernels/depthwiseconv2d/generic/neon/qasymm8.cpp
 	cpu/kernels/depthwiseconv2d/generic/neon/qasymm8_signed.cpp
+	cpu/kernels/dequantize/generic/neon/fp16.cpp
+	cpu/kernels/dequantize/generic/neon/fp32.cpp
 	cpu/kernels/directconv2d/nchw/all.cpp
 	cpu/kernels/directconv2d/nchw/fp16.cpp
 	cpu/kernels/directconv2d/nhwc/neon/fp16.cpp
@@ -812,7 +814,6 @@
 	cpu/kernels/quantize/generic/neon/fp16.cpp
 	cpu/kernels/quantize/generic/neon/fp32.cpp
 	cpu/kernels/quantize/generic/neon/integer.cpp
-	cpu/kernels/quantize/generic/neon/vquantize.cpp
 	cpu/kernels/range/generic/neon/fp16.cpp
 	cpu/kernels/range/generic/neon/fp32.cpp
 	cpu/kernels/range/generic/neon/integer.cpp
diff --git a/src/cpu/kernels/CpuDequantizeKernel.cpp b/src/cpu/kernels/CpuDequantizeKernel.cpp
index d17128b..6154ad3 100644
--- a/src/cpu/kernels/CpuDequantizeKernel.cpp
+++ b/src/cpu/kernels/CpuDequantizeKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,12 +29,14 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
 
+#include "src/core/common/Registrars.h"
 #include "src/core/CPP/Validate.h"
 #include "src/core/helpers/AutoConfiguration.h"
 #include "src/core/helpers/WindowHelpers.h"
 #include "src/core/NEON/NEAsymm.h"
 #include "src/core/NEON/NESymm.h"
 #include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/kernels/dequantize/generic/neon/list.h"
 
 #include <arm_neon.h>
 
@@ -62,301 +64,6 @@
 
     return Status{};
 }
-
-template <typename T>
-inline void store_result(T *ptr, const float32x4x4_t &v)
-{
-    ARM_COMPUTE_UNUSED(ptr, v);
-}
-
-template <>
-inline void store_result<float>(float *ptr, const float32x4x4_t &v)
-{
-    wrapper::vstore(ptr, v.val[0]);
-    wrapper::vstore(ptr + 4, v.val[1]);
-    wrapper::vstore(ptr + 8, v.val[2]);
-    wrapper::vstore(ptr + 12, v.val[3]);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline void store_result<float16_t>(float16_t *ptr, const float32x4x4_t &v)
-{
-    wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
-    wrapper::vstore(ptr + 8, vcombine_f16(vcvt_f16_f32(v.val[2]), vcvt_f16_f32(v.val[3])));
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <typename T>
-inline void store_result(T *ptr, const float32x4x2_t &v)
-{
-    ARM_COMPUTE_UNUSED(ptr, v);
-}
-
-template <>
-inline void store_result<float>(float *ptr, const float32x4x2_t &v)
-{
-    wrapper::vstore(ptr, v.val[0]);
-    wrapper::vstore(ptr + 4, v.val[1]);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline void store_result<float16_t>(float16_t *ptr, const float32x4x2_t &v)
-{
-    wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <typename TOut, typename TIn>
-void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window)
-{
-    const UniformQuantizationInfo &qinfo  = input->info()->quantization_info().uniform();
-    const float                    scale  = qinfo.scale;
-    const int32_t                  offset = qinfo.offset;
-
-    const int  window_step_x  = 16;
-    const auto window_start_x = static_cast<int>(window.x().start());
-    const auto window_end_x   = static_cast<int>(window.x().end());
-
-    // Collapse window and reset first dimension to handle tail calculations manually
-    Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
-    win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-    // Create iterators
-    Iterator in(input, win_collapsed);
-    Iterator out(output, win_collapsed);
-
-    execute_window_loop(
-        win_collapsed,
-        [&](const Coordinates &)
-        {
-            const auto in_ptr  = reinterpret_cast<const TIn *>(in.ptr());
-            const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
-
-            int x = window_start_x;
-            for (; x <= (window_end_x - window_step_x); x += window_step_x)
-            {
-                const auto vin  = wrapper::vloadq(in_ptr + x);
-                const auto vdeq = vdequantize(vin, scale, offset);
-
-                store_result(reinterpret_cast<TOut *>(out_ptr + x), vdeq);
-            }
-
-            // Compute left-over elements
-            for (; x < window_end_x; ++x)
-            {
-                auto val       = *(in_ptr + x);
-                *(out_ptr + x) = static_cast<TOut>(Qasymm8QuantizationHelper<TIn>::dequantize(val, qinfo));
-            }
-        },
-        in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window)
-{
-    const auto scale = input->info()->quantization_info().scale();
-
-    const int  window_step_x  = 16;
-    const auto window_start_x = static_cast<int>(window.x().start());
-    const auto window_end_x   = static_cast<int>(window.x().end());
-
-    // Reset first dimension to handle tail calculations manually
-    Window win(window);
-    win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-    // Create iterators
-    Iterator in(input, win);
-    Iterator out(output, win);
-
-    execute_window_loop(
-        win,
-        [&](const Coordinates &id)
-        {
-            const auto in_ptr  = reinterpret_cast<const int8_t *>(in.ptr());
-            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
-            int x = window_start_x;
-            for (; x <= (window_end_x - window_step_x); x += window_step_x)
-            {
-                const auto vin  = wrapper::vloadq(in_ptr + x);
-                const auto vdeq = vdequantize(vin, scale[id.z()]);
-
-                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
-            }
-
-            // Compute left-over elements
-            for (; x < window_end_x; ++x)
-            {
-                int8_t val     = *(in_ptr + x);
-                *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()]));
-            }
-        },
-        in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window)
-{
-    const auto scale = input->info()->quantization_info().scale();
-
-    const int  window_step_x  = 16;
-    const auto window_start_x = static_cast<int>(window.x().start());
-    const auto window_end_x   = static_cast<int>(window.x().end());
-
-    // Reset first dimension to handle tail calculations manually
-    Window win(window);
-    win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-    // Create iterators
-    Iterator in(input, win);
-    Iterator out(output, win);
-
-    execute_window_loop(
-        win,
-        [&](const Coordinates &)
-        {
-            const auto in_ptr  = reinterpret_cast<const int8_t *>(in.ptr());
-            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
-            int x = window_start_x;
-            for (; x <= (window_end_x - window_step_x); x += window_step_x)
-            {
-                const float32x4x4_t vscale = {{scale[x + 0], scale[x + 1], scale[x + 2], scale[x + 3], scale[x + 4],
-                                               scale[x + 5], scale[x + 6], scale[x + 7], scale[x + 8], scale[x + 9],
-                                               scale[x + 10], scale[x + 11], scale[x + 12], scale[x + 13],
-                                               scale[x + 14], scale[x + 15]}};
-                const auto          vin    = wrapper::vloadq(in_ptr + x);
-                const auto          vdeq   = vdequantize(vin, vscale);
-
-                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
-            }
-
-            // Compute left-over elements
-            for (; x < window_end_x; ++x)
-            {
-                int8_t val     = *(in_ptr + x);
-                *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x]));
-            }
-        },
-        in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window)
-{
-    const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
-    const float                    scale = qinfo.scale;
-
-    const int  window_step_x  = 16;
-    const auto window_start_x = static_cast<int>(window.x().start());
-    const auto window_end_x   = static_cast<int>(window.x().end());
-
-    // Collapse window and reset first dimension to handle tail calculations manually
-    Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
-    win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-    // Create iterators
-    Iterator in(input, win_collapsed);
-    Iterator out(output, win_collapsed);
-
-    execute_window_loop(
-        win_collapsed,
-        [&](const Coordinates &)
-        {
-            const auto in_ptr  = reinterpret_cast<const int8_t *>(in.ptr());
-            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
-            int x = window_start_x;
-            for (; x <= (window_end_x - window_step_x); x += window_step_x)
-            {
-                const auto vin  = wrapper::vloadq(in_ptr + x);
-                const auto vdeq = vdequantize(vin, scale);
-
-                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
-            }
-
-            // Compute left-over elements
-            for (; x < window_end_x; ++x)
-            {
-                int8_t val     = *(in_ptr + x);
-                *(out_ptr + x) = static_cast<T>(dequantize(val, scale));
-            }
-        },
-        in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm16(const ITensor *input, ITensor *output, const Window &window)
-{
-    const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
-    const float                    scale = qinfo.scale;
-
-    const int  window_step_x  = 8;
-    const auto window_start_x = static_cast<int>(window.x().start());
-    const auto window_end_x   = static_cast<int>(window.x().end());
-
-    // Collapse window and reset first dimension to handle tail calculations manually
-    Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
-    win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-    // Create iterators
-    Iterator in(input, win_collapsed);
-    Iterator out(output, win_collapsed);
-
-    execute_window_loop(
-        win_collapsed,
-        [&](const Coordinates &)
-        {
-            const auto in_ptr  = reinterpret_cast<const int16_t *>(in.ptr());
-            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
-            int x = window_start_x;
-            for (; x <= (window_end_x - window_step_x); x += window_step_x)
-            {
-                const auto vin  = wrapper::vloadq(in_ptr + x);
-                const auto vdeq = vdequantize_int16(vin, scale);
-
-                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
-            }
-
-            // Compute left-over elements
-            for (; x < window_end_x; ++x)
-            {
-                int16_t val    = *(in_ptr + x);
-                *(out_ptr + x) = static_cast<T>(dequantize_qsymm16(val, scale));
-            }
-        },
-        in, out);
-}
-
-template <typename T>
-void run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
-{
-    switch (input->info()->data_type())
-    {
-        case DataType::QASYMM8:
-            run_dequantization_qasymm8<T, uint8_t>(input, output, window);
-            break;
-        case DataType::QASYMM8_SIGNED:
-            run_dequantization_qasymm8<T, int8_t>(input, output, window);
-            break;
-        case DataType::QSYMM8_PER_CHANNEL:
-            input->info()->data_layout() == DataLayout::NHWC
-                ? run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window)
-                : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window);
-            break;
-        case DataType::QSYMM8:
-            run_dequantization_qsymm8<T>(input, output, window);
-            break;
-        case DataType::QSYMM16:
-            run_dequantization_qsymm16<T>(input, output, window);
-            break;
-        default:
-            ARM_COMPUTE_ERROR("Unsupported data type.");
-    }
-}
 } // namespace
 
 void CpuDequantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
@@ -370,6 +77,20 @@
     auto_init_if_empty(*dst, src->tensor_shape(), 1, DataType::F32);
 
     ICpuKernel::configure(win);
+
+    switch (dst->data_type())
+    {
+        case DataType::F32:
+            _func = REGISTER_FP32_NEON(fp32_run_dequantization_core);
+            break;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+        case DataType::F16:
+            _func = REGISTER_FP32_NEON(fp16_run_dequantization_core);
+            break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+        default:
+            ARM_COMPUTE_ERROR("Unsupported data type.");
+    }
 }
 
 Status CpuDequantizeKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
@@ -386,20 +107,7 @@
 
     const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
     auto       dst = tensors.get_tensor(TensorType::ACL_DST);
-
-    switch (dst->info()->data_type())
-    {
-        case DataType::F32:
-            run_dequantization_core<float>(src, dst, window);
-            break;
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        case DataType::F16:
-            run_dequantization_core<float16_t>(src, dst, window);
-            break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-        default:
-            ARM_COMPUTE_ERROR("Unsupported data type.");
-    }
+    _func(src, dst, window);
 }
 const char *CpuDequantizeKernel::name() const
 {
diff --git a/src/cpu/kernels/CpuDequantizeKernel.h b/src/cpu/kernels/CpuDequantizeKernel.h
index 6ed5858..d8b6444 100644
--- a/src/cpu/kernels/CpuDequantizeKernel.h
+++ b/src/cpu/kernels/CpuDequantizeKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H
-#define ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H
 
 #include "src/core/common/Macros.h"
 #include "src/cpu/ICpuKernel.h"
@@ -56,8 +56,16 @@
     // Inherited methods overridden:
     void        run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
     const char *name() const override;
+
+private:
+    /** Common signature for all the specialised @ref CpuDequantizeKernel functions
+     *
+     * @param[in] window Region on which to execute the kernel.
+     */
+    using DequantizeFunctionExecutorPtr = void (*)(const ITensor *input, ITensor *output, const Window &window);
+    DequantizeFunctionExecutorPtr _func{nullptr};
 };
 } // namespace kernels
 } // namespace cpu
 } // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H
diff --git a/src/cpu/kernels/quantize/generic/neon/vquantize.cpp b/src/cpu/kernels/dequantize/generic/neon/fp16.cpp
similarity index 75%
copy from src/cpu/kernels/quantize/generic/neon/vquantize.cpp
copy to src/cpu/kernels/dequantize/generic/neon/fp16.cpp
index d40702b..caffdf5 100644
--- a/src/cpu/kernels/quantize/generic/neon/vquantize.cpp
+++ b/src/cpu/kernels/dequantize/generic/neon/fp16.cpp
@@ -21,21 +21,17 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "impl.h"
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+#include "src/cpu/kernels/dequantize/generic/neon/impl.h"
+
 namespace arm_compute
 {
 namespace cpu
 {
-template <>
-vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
+void fp16_run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
 {
-    return vquantize(qv, qi);
-}
-
-template <>
-vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
-{
-    return vquantize_signed(qv, qi);
+    run_dequantization_core<float16_t>(input, output, window);
 }
 } // namespace cpu
 } // namespace arm_compute
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/quantize/generic/neon/vquantize.cpp b/src/cpu/kernels/dequantize/generic/neon/fp32.cpp
similarity index 78%
rename from src/cpu/kernels/quantize/generic/neon/vquantize.cpp
rename to src/cpu/kernels/dequantize/generic/neon/fp32.cpp
index d40702b..58e987b 100644
--- a/src/cpu/kernels/quantize/generic/neon/vquantize.cpp
+++ b/src/cpu/kernels/dequantize/generic/neon/fp32.cpp
@@ -21,21 +21,15 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "impl.h"
+#include "src/cpu/kernels/dequantize/generic/neon/impl.h"
+
 namespace arm_compute
 {
 namespace cpu
 {
-template <>
-vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
+void fp32_run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
 {
-    return vquantize(qv, qi);
-}
-
-template <>
-vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
-{
-    return vquantize_signed(qv, qi);
+    run_dequantization_core<float>(input, output, window);
 }
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/dequantize/generic/neon/impl.h b/src/cpu/kernels/dequantize/generic/neon/impl.h
new file mode 100644
index 0000000..7197d4d
--- /dev/null
+++ b/src/cpu/kernels/dequantize/generic/neon/impl.h
@@ -0,0 +1,340 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H
+#define ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H
+
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/Window.h"
+
+#include "src/core/NEON/NEAsymm.h"
+#include "src/core/NEON/NESymm.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/kernels/dequantize/generic/neon/list.h"
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+template <typename T>
+inline void store_result(T *ptr, const float32x4x4_t &v)
+{
+    ARM_COMPUTE_UNUSED(ptr, v);
+}
+
+template <>
+inline void store_result<float>(float *ptr, const float32x4x4_t &v)
+{
+    wrapper::vstore(ptr, v.val[0]);
+    wrapper::vstore(ptr + 4, v.val[1]);
+    wrapper::vstore(ptr + 8, v.val[2]);
+    wrapper::vstore(ptr + 12, v.val[3]);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void store_result<float16_t>(float16_t *ptr, const float32x4x4_t &v)
+{
+    wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
+    wrapper::vstore(ptr + 8, vcombine_f16(vcvt_f16_f32(v.val[2]), vcvt_f16_f32(v.val[3])));
+}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+template <typename T>
+inline void store_result(T *ptr, const float32x4x2_t &v)
+{
+    ARM_COMPUTE_UNUSED(ptr, v);
+}
+
+template <>
+inline void store_result<float>(float *ptr, const float32x4x2_t &v)
+{
+    wrapper::vstore(ptr, v.val[0]);
+    wrapper::vstore(ptr + 4, v.val[1]);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void store_result<float16_t>(float16_t *ptr, const float32x4x2_t &v)
+{
+    wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
+}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+template <typename TOut, typename TIn>
+void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window)
+{
+    const UniformQuantizationInfo &qinfo  = input->info()->quantization_info().uniform();
+    const float                    scale  = qinfo.scale;
+    const int32_t                  offset = qinfo.offset;
+
+    const int  window_step_x  = 16;
+    const auto window_start_x = static_cast<int>(window.x().start());
+    const auto window_end_x   = static_cast<int>(window.x().end());
+
+    // Collapse window and reset first dimension to handle tail calculations manually
+    Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+    win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Create iterators
+    Iterator in(input, win_collapsed);
+    Iterator out(output, win_collapsed);
+
+    execute_window_loop(
+        win_collapsed,
+        [&](const Coordinates &)
+        {
+            const auto in_ptr  = reinterpret_cast<const TIn *>(in.ptr());
+            const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
+
+            int x = window_start_x;
+            for (; x <= (window_end_x - window_step_x); x += window_step_x)
+            {
+                const auto vin  = wrapper::vloadq(in_ptr + x);
+                const auto vdeq = vdequantize(vin, scale, offset);
+
+                store_result(reinterpret_cast<TOut *>(out_ptr + x), vdeq);
+            }
+
+            // Compute left-over elements
+            for (; x < window_end_x; ++x)
+            {
+                auto val       = *(in_ptr + x);
+                *(out_ptr + x) = static_cast<TOut>(Qasymm8QuantizationHelper<TIn>::dequantize(val, qinfo));
+            }
+        },
+        in, out);
+}
+
+template <typename T>
+void run_dequantization_qsymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window)
+{
+    const auto scale = input->info()->quantization_info().scale();
+
+    const int  window_step_x  = 16;
+    const auto window_start_x = static_cast<int>(window.x().start());
+    const auto window_end_x   = static_cast<int>(window.x().end());
+
+    // Reset first dimension to handle tail calculations manually
+    Window win(window);
+    win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Create iterators
+    Iterator in(input, win);
+    Iterator out(output, win);
+
+    execute_window_loop(
+        win,
+        [&](const Coordinates &id)
+        {
+            const auto in_ptr  = reinterpret_cast<const int8_t *>(in.ptr());
+            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+            int x = window_start_x;
+            for (; x <= (window_end_x - window_step_x); x += window_step_x)
+            {
+                const auto vin  = wrapper::vloadq(in_ptr + x);
+                const auto vdeq = vdequantize(vin, scale[id.z()]);
+
+                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+            }
+
+            // Compute left-over elements
+            for (; x < window_end_x; ++x)
+            {
+                int8_t val     = *(in_ptr + x);
+                *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()]));
+            }
+        },
+        in, out);
+}
+
+template <typename T>
+void run_dequantization_qsymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window)
+{
+    const auto scale = input->info()->quantization_info().scale();
+
+    const int  window_step_x  = 16;
+    const auto window_start_x = static_cast<int>(window.x().start());
+    const auto window_end_x   = static_cast<int>(window.x().end());
+
+    // Reset first dimension to handle tail calculations manually
+    Window win(window);
+    win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Create iterators
+    Iterator in(input, win);
+    Iterator out(output, win);
+
+    execute_window_loop(
+        win,
+        [&](const Coordinates &)
+        {
+            const auto in_ptr  = reinterpret_cast<const int8_t *>(in.ptr());
+            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+            int x = window_start_x;
+            for (; x <= (window_end_x - window_step_x); x += window_step_x)
+            {
+                const float32x4x4_t vscale = {{scale[x + 0], scale[x + 1], scale[x + 2], scale[x + 3], scale[x + 4],
+                                               scale[x + 5], scale[x + 6], scale[x + 7], scale[x + 8], scale[x + 9],
+                                               scale[x + 10], scale[x + 11], scale[x + 12], scale[x + 13],
+                                               scale[x + 14], scale[x + 15]}};
+                const auto          vin    = wrapper::vloadq(in_ptr + x);
+                const auto          vdeq   = vdequantize(vin, vscale);
+
+                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+            }
+
+            // Compute left-over elements
+            for (; x < window_end_x; ++x)
+            {
+                int8_t val     = *(in_ptr + x);
+                *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x]));
+            }
+        },
+        in, out);
+}
+
+template <typename T>
+void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window)
+{
+    const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
+    const float                    scale = qinfo.scale;
+
+    const int  window_step_x  = 16;
+    const auto window_start_x = static_cast<int>(window.x().start());
+    const auto window_end_x   = static_cast<int>(window.x().end());
+
+    // Collapse window and reset first dimension to handle tail calculations manually
+    Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+    win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Create iterators
+    Iterator in(input, win_collapsed);
+    Iterator out(output, win_collapsed);
+
+    execute_window_loop(
+        win_collapsed,
+        [&](const Coordinates &)
+        {
+            const auto in_ptr  = reinterpret_cast<const int8_t *>(in.ptr());
+            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+            int x = window_start_x;
+            for (; x <= (window_end_x - window_step_x); x += window_step_x)
+            {
+                const auto vin  = wrapper::vloadq(in_ptr + x);
+                const auto vdeq = vdequantize(vin, scale);
+
+                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+            }
+
+            // Compute left-over elements
+            for (; x < window_end_x; ++x)
+            {
+                int8_t val     = *(in_ptr + x);
+                *(out_ptr + x) = static_cast<T>(dequantize(val, scale));
+            }
+        },
+        in, out);
+}
+
+template <typename T>
+void run_dequantization_qsymm16(const ITensor *input, ITensor *output, const Window &window)
+{
+    const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
+    const float                    scale = qinfo.scale;
+
+    const int  window_step_x  = 8;
+    const auto window_start_x = static_cast<int>(window.x().start());
+    const auto window_end_x   = static_cast<int>(window.x().end());
+
+    // Collapse window and reset first dimension to handle tail calculations manually
+    Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+    win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Create iterators
+    Iterator in(input, win_collapsed);
+    Iterator out(output, win_collapsed);
+
+    execute_window_loop(
+        win_collapsed,
+        [&](const Coordinates &)
+        {
+            const auto in_ptr  = reinterpret_cast<const int16_t *>(in.ptr());
+            const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+            int x = window_start_x;
+            for (; x <= (window_end_x - window_step_x); x += window_step_x)
+            {
+                const auto vin  = wrapper::vloadq(in_ptr + x);
+                const auto vdeq = vdequantize_int16(vin, scale);
+
+                store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+            }
+
+            // Compute left-over elements
+            for (; x < window_end_x; ++x)
+            {
+                int16_t val    = *(in_ptr + x);
+                *(out_ptr + x) = static_cast<T>(dequantize_qsymm16(val, scale));
+            }
+        },
+        in, out);
+}
+
+template <typename T>
+void run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
+{
+    switch (input->info()->data_type())
+    {
+        case DataType::QASYMM8:
+            run_dequantization_qasymm8<T, uint8_t>(input, output, window);
+            break;
+        case DataType::QASYMM8_SIGNED:
+            run_dequantization_qasymm8<T, int8_t>(input, output, window);
+            break;
+        case DataType::QSYMM8_PER_CHANNEL:
+            input->info()->data_layout() == DataLayout::NHWC
+                ? run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window)
+                : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window);
+            break;
+        case DataType::QSYMM8:
+            run_dequantization_qsymm8<T>(input, output, window);
+            break;
+        case DataType::QSYMM16:
+            run_dequantization_qsymm16<T>(input, output, window);
+            break;
+        default:
+            ARM_COMPUTE_ERROR("Unsupported data type.");
+    }
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H
diff --git a/src/cpu/kernels/quantize/generic/neon/impl_fp32.h b/src/cpu/kernels/dequantize/generic/neon/list.h
similarity index 71%
rename from src/cpu/kernels/quantize/generic/neon/impl_fp32.h
rename to src/cpu/kernels/dequantize/generic/neon/list.h
index 00ae242..678eb2c 100644
--- a/src/cpu/kernels/quantize/generic/neon/impl_fp32.h
+++ b/src/cpu/kernels/dequantize/generic/neon/list.h
@@ -21,24 +21,23 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP32_H
-#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP32_H
+#ifndef ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H
+#define ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H
 
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/core/NEON/NEAsymm.h"
+#include "arm_compute/core/Helpers.h"
 
 namespace arm_compute
 {
 namespace cpu
 {
-inline float32x4x4_t load_value(const float *input_ptr)
-{
-    return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8),
-            wrapper::vloadq(input_ptr + 12)};
-}
+
+#define DECLARE_DEQUANTIZE_KERNEL(func_name) void func_name(const ITensor *input, ITensor *output, const Window &window)
+
+DECLARE_DEQUANTIZE_KERNEL(fp32_run_dequantization_core);
+DECLARE_DEQUANTIZE_KERNEL(fp16_run_dequantization_core);
+
+#undef DECLARE_DEQUANTIZE_KERNEL
 
 } // namespace cpu
 } // namespace arm_compute
-
-#include "impl.h"
-#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP32_H
+#endif // ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H
diff --git a/src/cpu/kernels/quantize/generic/neon/fp16.cpp b/src/cpu/kernels/quantize/generic/neon/fp16.cpp
index 456a3bd..37bfb5b 100644
--- a/src/cpu/kernels/quantize/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/quantize/generic/neon/fp16.cpp
@@ -22,7 +22,7 @@
  * SOFTWARE.
  */
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
-#include "src/cpu/kernels/quantize/generic/neon/impl_fp16.h"
+#include "src/cpu/kernels/quantize/generic/neon/impl.h"
 
 namespace arm_compute
 {
diff --git a/src/cpu/kernels/quantize/generic/neon/fp32.cpp b/src/cpu/kernels/quantize/generic/neon/fp32.cpp
index 15f52b2..0cba332 100644
--- a/src/cpu/kernels/quantize/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/quantize/generic/neon/fp32.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "src/cpu/kernels/quantize/generic/neon/impl_fp32.h"
+#include "src/cpu/kernels/quantize/generic/neon/impl.h"
 
 namespace arm_compute
 {
diff --git a/src/cpu/kernels/quantize/generic/neon/impl.h b/src/cpu/kernels/quantize/generic/neon/impl.h
index 1861fca..9954a76 100644
--- a/src/cpu/kernels/quantize/generic/neon/impl.h
+++ b/src/cpu/kernels/quantize/generic/neon/impl.h
@@ -43,11 +43,39 @@
     return arm_compute::convert_to_float32x4x4<Tx16_t>(wrapper::vloadq(input_ptr));
 }
 
+template <>
+inline float32x4x4_t load_value(const float *input_ptr)
+{
+    return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8),
+            wrapper::vloadq(input_ptr + 12)};
+}
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline float32x4x4_t load_value(const float16_t *input_ptr)
+{
+    return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)),
+            vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))};
+}
+
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
 template <typename element_type>
 using vector_type = wrapper::traits::neon_vector_t<element_type, window_step>;
 
 template <typename quantized_type>
-vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const UniformQuantizationInfo &qi);
+inline vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const UniformQuantizationInfo &qi);
+
+template <>
+inline vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
+{
+    return vquantize(qv, qi);
+}
+
+template <>
+inline vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
+{
+    return vquantize_signed(qv, qi);
+}
 
 template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type>
 inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper)
diff --git a/src/cpu/kernels/quantize/generic/neon/impl_fp16.h b/src/cpu/kernels/quantize/generic/neon/impl_fp16.h
deleted file mode 100644
index 47f1b90..0000000
--- a/src/cpu/kernels/quantize/generic/neon/impl_fp16.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Copyright (c) 2024 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP16_H
-#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP16_H
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-#include "src/core/helpers/WindowHelpers.h"
-#include "src/core/NEON/NEAsymm.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-
-inline float32x4x4_t load_value(const float16_t *input_ptr)
-{
-    return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)),
-            vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))};
-}
-
-} // namespace cpu
-} // namespace arm_compute
-#include "impl.h"
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP16_H
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp
index 41584e9..143bb54 100644
--- a/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp
@@ -23,7 +23,7 @@
  */
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
 
-#include "src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h"
+#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h"
 
 namespace arm_compute
 {
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/impl.h b/src/cpu/kernels/reduction_layer/generic/neon/impl.h
index 611d83c..3fa821d 100644
--- a/src/cpu/kernels/reduction_layer/generic/neon/impl.h
+++ b/src/cpu/kernels/reduction_layer/generic/neon/impl.h
@@ -26,7 +26,6 @@
 
 #include "arm_compute/core/Coordinates.h"
 #include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/TensorInfo.h"
 
 #include "src/core/NEON/NEMath.h"
@@ -247,6 +246,91 @@
     return (res - 0xFFFFFFFF);
 }
 
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+uint32x4x4_t inline calculate_index(
+    uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
+{
+    uint32x4x2_t mask{0};
+    uint16x8_t   mask_u16{0};
+    if (op == ReductionOperation::ARG_IDX_MIN)
+    {
+        mask_u16 = wrapper::vcgt(b, a);
+    }
+    else
+    {
+        mask_u16 = wrapper::vclt(b, a);
+    }
+    mask.val[0]          = wrapper::vmovl(wrapper::vgetlow(mask_u16));
+    mask.val[1]          = wrapper::vmovl(wrapper::vgethigh(mask_u16));
+    uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}};
+    if (axis != 0)
+    {
+        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+    }
+    uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
+                        wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0};
+
+    return res;
+}
+
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+inline float16x4_t calculate_min(float16x8_t in)
+{
+    auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+    pmin      = wrapper::vpmin(pmin, pmin);
+    return wrapper::vpmin(pmin, pmin);
+}
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+inline float16x4_t calculate_max(float16x8_t in)
+{
+    auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+    pmax      = wrapper::vpmax(pmax, pmax);
+    return wrapper::vpmax(pmax, pmax);
+}
+
+template <>
+inline uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
+{
+    uint32x4x2_t res_idx_mask{0};
+    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
+    uint16x8_t   mask_u16;
+    if (op == ReductionOperation::ARG_IDX_MIN)
+    {
+        auto pmin = calculate_min(vec_res_value);
+        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
+    }
+    else
+    {
+        auto pmax = calculate_max(vec_res_value);
+        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
+    }
+
+    // Widen vectors
+    auto wide_u32_1 =
+        wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
+    auto wide_u32_2 =
+        wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
+    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
+    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
+    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
+    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
+
+    uint32_t res  = 0xFFFFFFFF;
+    uint32_t iter = 0;
+    do
+    {
+        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
+        pmin      = wrapper::vpmin(pmin, pmin);
+        res       = std::min(wrapper::vgetlane(pmin, 0), res);
+        iter++;
+    } while (iter < 2);
+
+    return (res - 0xFFFFFFFF);
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
 template <class F>
 class Reducer
 {
@@ -933,6 +1017,12 @@
                     if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                     {
                         wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+                        if (std::is_same<T, float16_t>::value)
+                        {
+                            wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
+                        }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                     }
                     else
                     {
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h b/src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h
deleted file mode 100644
index c7ca36d..0000000
--- a/src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h
+++ /dev/null
@@ -1,718 +0,0 @@
-/*
- * Copyright (c) 2024 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H
-#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H
-
-#include "arm_compute/core/Coordinates.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/TensorInfo.h"
-
-#include "src/core/NEON/NEMath.h"
-#include "src/core/NEON/wrapper/wrapper.h"
-#include "support/SaturateCast.h"
-
-#include <arm_neon.h>
-
-namespace arm_compute
-{
-// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized
-void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
-{
-    auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
-    wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-uint32x4x4_t
-calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
-{
-    uint32x4x2_t mask{0};
-    uint16x8_t   mask_u16{0};
-    if (op == ReductionOperation::ARG_IDX_MIN)
-    {
-        mask_u16 = wrapper::vcgt(b, a);
-    }
-    else
-    {
-        mask_u16 = wrapper::vclt(b, a);
-    }
-    mask.val[0]          = wrapper::vmovl(wrapper::vgetlow(mask_u16));
-    mask.val[1]          = wrapper::vmovl(wrapper::vgethigh(mask_u16));
-    uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}};
-    if (axis != 0)
-    {
-        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
-        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
-    }
-    uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
-                        wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0};
-
-    return res;
-}
-
-// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
-inline float16x4_t calculate_min(float16x8_t in)
-{
-    auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
-    pmin      = wrapper::vpmin(pmin, pmin);
-    return wrapper::vpmin(pmin, pmin);
-}
-// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
-inline float16x4_t calculate_max(float16x8_t in)
-{
-    auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
-    pmax      = wrapper::vpmax(pmax, pmax);
-    return wrapper::vpmax(pmax, pmax);
-}
-
-uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
-{
-    uint32x4x2_t res_idx_mask{0};
-    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
-    uint16x8_t   mask_u16;
-    if (op == ReductionOperation::ARG_IDX_MIN)
-    {
-        auto pmin = calculate_min(vec_res_value);
-        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
-    }
-    else
-    {
-        auto pmax = calculate_max(vec_res_value);
-        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
-    }
-
-    // Widen vectors
-    auto wide_u32_1 =
-        wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
-    auto wide_u32_2 =
-        wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
-    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
-    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
-    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
-    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
-
-    uint32_t res  = 0xFFFFFFFF;
-    uint32_t iter = 0;
-    do
-    {
-        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
-        pmin      = wrapper::vpmin(pmin, pmin);
-        res       = std::min(wrapper::vgetlane(pmin, 0), res);
-        iter++;
-    } while (iter < 2);
-
-    return (res - 0xFFFFFFFF);
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-template <class F>
-class Reducer
-{
-public:
-    static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
-    {
-        // Set out window
-        Window out_window(window);
-        out_window.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-        f(window, out_window, input, output, op);
-    }
-    static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
-    {
-        // Set in window
-        Window in_window(window);
-        Window out_window(window);
-
-        in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
-        out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
-
-        f(in_window, out_window, input, output, 1, op);
-    }
-    static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
-    {
-        // Set in window
-        Window in_window(window);
-        Window out_window(window);
-
-        in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
-        out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
-
-        f(in_window, out_window, input, output, 2, op);
-    }
-    static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
-    {
-        // Set in/out window
-        Window in_window(window);
-        Window out_window(window);
-
-        in_window.set(3, Window::Dimension(0, 1, 1));
-        out_window.set(3, Window::Dimension(0, 1, 1));
-
-        f(in_window, out_window, input, output, 3, op);
-    }
-};
-
-template <typename T, int S>
-struct RedOpX
-{
-    /** SIMD vector tag type. */
-    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
-
-    inline void operator()(
-        const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
-    {
-        const size_t input_dim_0    = in->info()->dimension(0);
-        const int    window_step_x  = 16 / sizeof(T);
-        const auto   window_start_x = static_cast<int>(in_window.x().start());
-        const auto   window_end_x   = static_cast<int>(in_window.x().end());
-
-        Window in_win_no_pad = in_window;
-        in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-        Iterator input(in, in_win_no_pad);
-        Iterator output(out, out_window);
-
-        execute_window_loop(
-            in_win_no_pad,
-            [&](const Coordinates &)
-            {
-                const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
-
-                auto init_res_value = static_cast<T>(0.f);
-                switch (op)
-                {
-                    case ReductionOperation::ARG_IDX_MAX:
-                    case ReductionOperation::ARG_IDX_MIN:
-                    case ReductionOperation::MIN:
-                    case ReductionOperation::MAX:
-                    {
-                        init_res_value = static_cast<T>(*input_ptr);
-                        break;
-                    }
-                    case ReductionOperation::PROD:
-                    {
-                        init_res_value = static_cast<T>(1.f);
-                        break;
-                    }
-                    default:
-                        break;
-                }
-                auto         vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
-                uint32x4x4_t vec_res_idx{{0}};
-
-                // Compute window_step_x elements per iteration
-                int x = window_start_x;
-                for (; x <= (window_end_x - window_step_x); x += window_step_x)
-                {
-                    const auto vec_elements = wrapper::vloadq(input_ptr + x);
-                    switch (op)
-                    {
-                        case ReductionOperation::SUM_SQUARE:
-                            vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
-                            break;
-                        case ReductionOperation::MEAN_SUM:
-                        case ReductionOperation::SUM:
-                            vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
-                            break;
-                        case ReductionOperation::PROD:
-                            vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
-                            break;
-                        case ReductionOperation::ARG_IDX_MIN:
-                        {
-                            auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
-                            vec_res_idx   = calculate_index(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
-                            vec_res_value = temp_vec_res_value;
-                            break;
-                        }
-                        case ReductionOperation::ARG_IDX_MAX:
-                        {
-                            auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
-                            vec_res_idx   = calculate_index(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
-                            vec_res_value = temp_vec_res_value;
-                            break;
-                        }
-                        case ReductionOperation::MIN:
-                        {
-                            vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
-                            break;
-                        }
-                        case ReductionOperation::MAX:
-                        {
-                            vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
-                            break;
-                        }
-                        default:
-                            ARM_COMPUTE_ERROR("Not supported");
-                    }
-                }
-
-                switch (op)
-                {
-                    case ReductionOperation::SUM:
-                    case ReductionOperation::MEAN_SUM:
-                    case ReductionOperation::SUM_SQUARE:
-                    {
-#ifdef ARM_COMPUTE_DEBUG_ENABLED
-                        auto res = static_cast<T>(0.f);
-                        for (int i = 0; i < S; ++i)
-                        {
-                            res += wrapper::vgetlane(vec_res_value, i);
-                        }
-#else  // ARM_COMPUTE_DEBUG_ENABLED
-                        auto carry_res =
-                            wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
-                        for (int i = 0; i < S / 4; ++i)
-                        {
-                            carry_res = wrapper::vpadd(carry_res, carry_res);
-                        }
-                        auto res = wrapper::vgetlane(carry_res, 0);
-#endif // ARM_COMPUTE_DEBUG_ENABLED
-                        if (op == ReductionOperation::SUM_SQUARE)
-                        {
-                            // Compute left-over elements
-                            for (; x < window_end_x; ++x)
-                            {
-                                res += (*(input_ptr + x)) * (*(input_ptr + x));
-                            }
-                        }
-                        else
-                        {
-                            // Compute left-over elements
-                            for (; x < window_end_x; ++x)
-                            {
-                                res += *(input_ptr + x);
-                            }
-                        }
-
-                        if (op == ReductionOperation::MEAN_SUM)
-                        {
-                            res /= input_dim_0;
-                        }
-
-                        *(reinterpret_cast<T *>(output.ptr())) = res;
-                        break;
-                    }
-                    case ReductionOperation::PROD:
-                    {
-                        auto carry_res =
-                            wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
-                        T res = 1;
-                        for (int i = 0; i < S / 2; ++i)
-                        {
-                            res *= wrapper::vgetlane(carry_res, i);
-                        }
-
-                        // Compute left-over elements
-                        for (; x < window_end_x; ++x)
-                        {
-                            res *= *(input_ptr + x);
-                        }
-
-                        *(reinterpret_cast<T *>(output.ptr())) = res;
-                        break;
-                    }
-                    case ReductionOperation::ARG_IDX_MIN:
-                    {
-                        auto idx = calculate_vector_index(vec_res_idx, vec_res_value, op);
-                        auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
-
-                        // Compute left-over elements
-                        for (; x < window_end_x; ++x)
-                        {
-                            if (*(input_ptr + x) < res)
-                            {
-                                idx = x;
-                                res = *(input_ptr + x);
-                            }
-                        }
-                        *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
-                        break;
-                    }
-                    case ReductionOperation::ARG_IDX_MAX:
-                    {
-                        auto idx = calculate_vector_index(vec_res_idx, vec_res_value, op);
-                        auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
-
-                        // Compute left-over elements
-                        for (; x < window_end_x; ++x)
-                        {
-                            if (*(input_ptr + x) > res)
-                            {
-                                idx = x;
-                                res = *(input_ptr + x);
-                            }
-                        }
-                        *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
-                        break;
-                    }
-                    case ReductionOperation::MIN:
-                    {
-                        auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
-
-                        // Compute left-over elements
-                        for (; x < window_end_x; ++x)
-                        {
-                            res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
-                        }
-                        *(reinterpret_cast<T *>(output.ptr())) = res;
-                        break;
-                    }
-                    case ReductionOperation::MAX:
-                    {
-                        auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
-
-                        // Compute left-over elements
-                        for (; x < window_end_x; ++x)
-                        {
-                            res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
-                        }
-                        *(reinterpret_cast<T *>(output.ptr())) = res;
-                        break;
-                    }
-                    default:
-                        ARM_COMPUTE_ERROR("Not supported");
-                }
-            },
-            input, output);
-    }
-};
-
-template <typename T, int S>
-struct RedOpYZW
-{
-    /** SIMD vector tag type. */
-    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
-    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;
-
-    inline void operator()(const Window            &in_window,
-                           Window                  &out_window,
-                           const ITensor           *in,
-                           ITensor                 *out,
-                           int                      axis,
-                           const ReductionOperation op)
-    {
-        const TensorInfo in_info            = *(in->info());
-        const int        window_step_x      = 16 / sizeof(T);
-        const auto       window_start_x_tmp = static_cast<int>(in_window.x().start());
-        const auto       window_end_x_tmp   = static_cast<int>(in_window.x().end());
-        // As it split over x-axis, need to set the correct spiltted window start and end.
-        const auto window_start_x = static_cast<int>(0);
-        const auto window_end_x   = static_cast<int>(in_window.shape().x());
-
-        Window in_win_no_pad = in_window;
-        in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
-        Window out_win_no_pad = out_window;
-        out_win_no_pad.set(Window::DimX,
-                           Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
-
-        Iterator input(in, in_win_no_pad);
-        Iterator output(out, out_win_no_pad);
-
-        execute_window_loop(
-            in_win_no_pad,
-            [&](const Coordinates &)
-            {
-                const auto input_ptr = reinterpret_cast<T *>(input.ptr());
-
-                // Compute window_step_x elements per iteration
-                int x = window_start_x;
-                for (; x <= (window_end_x - window_step_x); x += window_step_x)
-                {
-                    neon_vector vec_res_value = {0};
-                    switch (op)
-                    {
-                        case ReductionOperation::ARG_IDX_MAX:
-                        case ReductionOperation::ARG_IDX_MIN:
-                        case ReductionOperation::MIN:
-                        case ReductionOperation::MAX:
-                        {
-                            vec_res_value = wrapper::vloadq(input_ptr + x);
-                            break;
-                        }
-                        case ReductionOperation::PROD:
-                        {
-                            vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
-                            break;
-                        }
-                        default:
-                        {
-                            vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
-                            break;
-                        }
-                    }
-                    uint32x4x4_t vec_res_idx{{0}};
-
-                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
-                    {
-                        const T *in_ptr =
-                            reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
-                        const auto vec_elements = wrapper::vloadq(in_ptr);
-                        switch (op)
-                        {
-                            case ReductionOperation::SUM:
-                            case ReductionOperation::MEAN_SUM:
-                                vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
-                                break;
-                            case ReductionOperation::SUM_SQUARE:
-                                vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
-                                break;
-                            case ReductionOperation::PROD:
-                                vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
-                                break;
-                            case ReductionOperation::ARG_IDX_MIN:
-                            {
-                                auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
-                                vec_res_idx =
-                                    calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
-                                vec_res_value = temp_vec_res_value;
-                                break;
-                            }
-                            case ReductionOperation::ARG_IDX_MAX:
-                            {
-                                auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
-                                vec_res_idx =
-                                    calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
-                                vec_res_value = temp_vec_res_value;
-                                break;
-                            }
-                            case ReductionOperation::MIN:
-                            {
-                                vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
-                                break;
-                            }
-                            case ReductionOperation::MAX:
-                            {
-                                vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
-                                break;
-                            }
-                            default:
-                                ARM_COMPUTE_ERROR("Not supported");
-                        }
-                    }
-
-                    if (op == ReductionOperation::MEAN_SUM)
-                    {
-                        auto vec_width_inv =
-                            wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
-                        vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
-                    }
-
-                    if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
-                    {
-                        wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-                        if (std::is_same<T, float16_t>::value)
-                        {
-                            wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
-                        }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-                    }
-                    else
-                    {
-                        wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
-                    }
-                }
-
-                // Compute left-over elements
-                for (; x < window_end_x; ++x)
-                {
-                    auto res_value = 0.f;
-                    switch (op)
-                    {
-                        case ReductionOperation::ARG_IDX_MAX:
-                        case ReductionOperation::ARG_IDX_MIN:
-                        case ReductionOperation::MIN:
-                        case ReductionOperation::MAX:
-                        {
-                            res_value = *(input_ptr + x);
-                            break;
-                        }
-                        case ReductionOperation::PROD:
-                        {
-                            res_value = static_cast<T>(1.f);
-                            break;
-                        }
-                        default:
-                        {
-                            res_value = static_cast<T>(0.f);
-                            break;
-                        }
-                    }
-
-                    uint32_t res_idx = 0;
-                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
-                    {
-                        const T *in_ptr =
-                            reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
-
-                        switch (op)
-                        {
-                            case ReductionOperation::SUM:
-                            case ReductionOperation::MEAN_SUM:
-                                res_value += *in_ptr;
-                                break;
-                            case ReductionOperation::SUM_SQUARE:
-                                res_value += *in_ptr * *in_ptr;
-                                break;
-                            case ReductionOperation::PROD:
-                                res_value *= *in_ptr;
-                                break;
-                            case ReductionOperation::ARG_IDX_MIN:
-                            {
-                                if (*in_ptr < res_value)
-                                {
-                                    res_value = *in_ptr;
-                                    res_idx   = dim;
-                                }
-                                break;
-                            }
-                            case ReductionOperation::ARG_IDX_MAX:
-                            {
-                                if (*in_ptr > res_value)
-                                {
-                                    res_value = *in_ptr;
-                                    res_idx   = dim;
-                                }
-                                break;
-                            }
-                            case ReductionOperation::MIN:
-                            {
-                                res_value = *in_ptr < res_value ? *in_ptr : res_value;
-                                break;
-                            }
-                            case ReductionOperation::MAX:
-                            {
-                                res_value = *in_ptr > res_value ? *in_ptr : res_value;
-                                break;
-                            }
-                            default:
-                                ARM_COMPUTE_ERROR("Not supported");
-                        }
-                    }
-
-                    if (op == ReductionOperation::MEAN_SUM)
-                    {
-                        res_value /= in_info.dimension(axis);
-                    }
-
-                    if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
-                    {
-                        *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
-                    }
-                    else
-                    {
-                        *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
-                    }
-                }
-            },
-            input, output);
-    }
-};
-
-template <typename T, int S, int axis, ReductionOperation op>
-struct RedOpYZW_complex
-{
-    /** SIMD vector tag type. */
-    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
-    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;
-
-    inline void operator()(
-        const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
-    {
-        ARM_COMPUTE_ERROR_ON(axis != 2);
-        ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM);
-
-        const TensorInfo in_info            = *(in->info());
-        const size_t     stride_z           = in_info.strides_in_bytes()[axis];
-        const int        window_step_x      = 16 / sizeof(T);
-        const auto       window_start_x_tmp = static_cast<int>(in_window.x().start());
-        const auto       window_end_x_tmp   = static_cast<int>(in_window.x().end());
-        // As it split over x-axis, need to set the correct spiltted window start and end.
-        const auto window_start_x = static_cast<int>(0);
-        const auto window_end_x   = static_cast<int>(in_window.shape().x());
-
-        Window in_win_no_pad = in_window;
-        in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
-        Window out_win_no_pad = out_window;
-        out_win_no_pad.set(Window::DimX,
-                           Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
-
-        Iterator input(in, in_win_no_pad);
-        Iterator output(out, out_win_no_pad);
-
-        execute_window_loop(
-            in_win_no_pad,
-            [&](const Coordinates &)
-            {
-                // Compute window_step_x elements per iteration
-                int x = window_start_x;
-                for (; x <= (window_end_x - window_step_x); x += window_step_x)
-                {
-                    neon_vector vec_res_value_0 = {0};
-                    neon_vector vec_res_value_1 = {0};
-
-                    vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
-                    vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
-
-                    T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
-                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
-                    {
-                        T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
-                        T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);
-
-                        const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
-                        const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
-
-                        vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
-                        vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
-                    }
-
-                    wrapper::vstore(out_ptr, vec_res_value_0);
-                    wrapper::vstore(out_ptr + 4, vec_res_value_1);
-                }
-
-                // Compute left-over elements
-                for (; x < window_end_x; ++x)
-                {
-                    auto res_value_0 = 0.f;
-                    auto res_value_1 = 0.f;
-
-                    T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
-                    for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
-                    {
-                        T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
-                        res_value_0 += *in_ptr;
-                        res_value_1 += *(in_ptr + 1);
-                    }
-                    *out_ptr       = res_value_0;
-                    *(out_ptr + 1) = res_value_1;
-                }
-            },
-            input, output);
-    }
-};
-
-} // namespace arm_compute
-#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H