Decouple NEInstanceNormalizationLayerKernel

Move the NCHW instance normalization implementation out of
NEInstanceNormalizationLayerKernel.cpp and into
src/cpu/kernels/instancenorm/generic/neon/, and dispatch to the
data-type specific micro-kernel at run time through a selector table.

Resolves COMPMID-4620
Signed-off-by: Dana Zlotnik <dana.zlotnik@arm.com>
Change-Id: I22c285339840493c9cfd4c1abfbc3768ad4db824
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6871
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
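
For reference, below is a minimal standalone sketch of the selector-table dispatch this patch introduces. The struct and function names mirror the ones added to NEInstanceNormalizationLayerKernel.cpp, but the simplified ukernel signature, the DataType enum, fake_fp32_ukernel and main are illustrative assumptions, not library code.

#include <cstdio>

enum class DataType { F16, F32 };

struct InstanceNormSelectorData
{
    DataType dt;
};

// The real micro-kernel signature takes the tensors, gamma/beta/epsilon, the
// mixed-precision flag and the execution window; a string stands in for it here.
using InstanceNormSelectorPtr = bool (*)(const InstanceNormSelectorData &data);
using InstanceNormUKernelPtr  = void (*)(const char *name);

struct InstanceNormKernel
{
    const char             *name;
    InstanceNormSelectorPtr is_selected;
    InstanceNormUKernelPtr  ukernel;
};

void fake_fp32_ukernel(const char *name)
{
    std::printf("running %s\n", name);
}

static const InstanceNormKernel available_kernels[] =
{
    {
        "fp32_neon_instancenorm",
        [](const InstanceNormSelectorData &data) { return data.dt == DataType::F32; },
        fake_fp32_ukernel
    },
};

// Return the first registered kernel whose predicate accepts the selection data.
const InstanceNormKernel *get_implementation(const InstanceNormSelectorData &data)
{
    for(const auto &uk : available_kernels)
    {
        if(uk.is_selected(data))
        {
            return &uk;
        }
    }
    return nullptr;
}

int main()
{
    const auto *uk = get_implementation(InstanceNormSelectorData{ DataType::F32 });
    if(uk != nullptr && uk->ukernel != nullptr)
    {
        uk->ukernel(uk->name);
    }
    return 0;
}

In the patch itself, run() builds the selector data from the input tensor's data type and forwards the tensors, gamma, beta, epsilon, the mixed-precision flag and the window to the selected ukernel.
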
diff --git a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
index d33431a..7164140 100644
--- a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,8 +34,10 @@
 #include "src/core/CPP/Validate.h"
 #include "src/core/NEON/NEMath.h"
 #include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/common/Registrars.h"
 #include "src/core/helpers/AutoConfiguration.h"
 #include "src/core/helpers/WindowHelpers.h"
+#include "src/cpu/kernels/instancenorm/list.h"
 
 #include <arm_neon.h>
 
@@ -43,137 +45,53 @@
 {
 namespace
 {
-template <typename InputType, typename AccType = InputType>
-void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
+struct InstanceNormSelectorData
 {
-    result        = wrapper::vadd(result, inputs);
-    result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
-}
+    DataType dt;
+};
 
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
+using InstanceNormSelectorPtr = std::add_pointer<bool(const InstanceNormSelectorData &data)>::type;
+using InstanceNormUKernelPtr = std::add_pointer<void(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)>::type;
+
+struct InstanceNormKernel
 {
-    vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
-    vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    const char                   *name;
+    const InstanceNormSelectorPtr is_selected;
+    InstanceNormUKernelPtr        ukernel;
+};
 
-template <typename InputType, typename AccType = InputType>
-InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
+static const InstanceNormKernel available_kernels[] =
 {
-    return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
-{
-    const auto  input_low   = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
-    const auto  input_high  = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
-    const auto  result_low  = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
-    const auto  result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
-    float16x8_t result      = wrapper::vcombine(result_low, result_high);
-
-    return result;
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-template <typename T, typename AccType = T>
-void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
-{
-    /** SIMD vector tag type. */
-    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-
-    // Clear X/Y dimensions on execution window as we handle the planes manually
-    Window win = window;
-    win.set(Window::DimX, Window::Dimension(0, 1, 1));
-    win.set(Window::DimY, Window::Dimension(0, 1, 1));
-
-    constexpr int      window_step_x  = 16 / sizeof(T);
-    const unsigned int elements_plane = input->info()->dimension(0) * output->info()->dimension(1);
-
-    Iterator input_it(input, win);
-    execute_window_loop(win, [&](const Coordinates & id)
     {
-        Window win_plane = window;
-        win_plane.set(Window::DimX, Window::Dimension(0, 1, 1));
-        win_plane.set(Window::DimZ, Window::Dimension(id[2], id[2] + 1, 1));
-        win_plane.set(3, Window::Dimension(id[3], id[3] + 1, 1));
-
-        Iterator input_plane_it(input, win_plane);
-        Iterator output_plane_it(output, win_plane);
-
-        auto sum_h_w         = static_cast<AccType>(0.f);
-        auto sum_squares_h_w = static_cast<AccType>(0.f);
-
-        execute_window_loop(win_plane, [&](const Coordinates &)
-        {
-            const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
-
-            auto vec_sum_h_w         = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
-            auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
-
-            // Compute S elements per iteration
-            int x = window.x().start();
-            for(; x <= (window.x().end() - window_step_x); x += window_step_x)
-            {
-                auto vec_input_val = wrapper::vloadq(input_ptr + x);
-                vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
-            }
-
-            auto vec2_sum_h_w         = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
-            auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));
-
-            vec2_sum_h_w         = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
-            vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
-
-            sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0);
-            sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);
-
-            // Compute left-over elements
-            for(; x < window.x().end(); ++x)
-            {
-                const auto value = static_cast<AccType>(*(input_ptr + x));
-                sum_h_w += value;
-                sum_squares_h_w += value * value;
-            }
-        },
-        input_plane_it, output_plane_it);
-
-        const auto mean_h_w = sum_h_w / elements_plane;
-        const auto var_h_w  = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
-
-        const auto multip_h_w     = gamma / std::sqrt(var_h_w + epsilon);
-        const auto vec_mean_h_w   = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
-        const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
-        const auto vec_beta       = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
-
-        execute_window_loop(win_plane, [&](const Coordinates &)
-        {
-            auto input_ptr  = reinterpret_cast<T *>(input_plane_it.ptr());
-            auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
-
-            // Compute S elements per iteration
-            int x = window.x().start();
-            //auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
-            for(; x <= (window.x().end() - window_step_x); x += window_step_x)
-            {
-                const auto vec_val        = wrapper::vloadq(input_ptr + x);
-                const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
-                wrapper::vstore(output_ptr + x, normalized_vec);
-            }
-
-            // Compute left-over elements
-            for(; x < window.x().end(); ++x)
-            {
-                const auto val    = static_cast<AccType>(*(input_ptr + x));
-                *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
-            }
-        },
-        input_plane_it, output_plane_it);
+        "fp32_neon_instancenorm",
+        [](const InstanceNormSelectorData & data) { return data.dt == DataType::F32; },
+        REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_instancenorm)
     },
-    input_it);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    {
+        "fp16_neon_instancenorm",
+        [](const InstanceNormSelectorData & data) { return data.dt == DataType::F16; },
+        REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_instancenorm)
+    },
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+};
+
+/** Micro-kernel selector
+ *
+ * @param[in] data Selection data passed to help pick the appropriate micro-kernel
+ *
+ * @return The matching micro-kernel if one is found, otherwise nullptr
+ */
+const InstanceNormKernel *get_implementation(const InstanceNormSelectorData &data)
+{
+    for(const auto &uk : available_kernels)
+    {
+        if(uk.is_selected(data))
+        {
+            return &uk;
+        }
+    }
+    return nullptr;
 }
 
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
@@ -210,7 +128,7 @@
 } // namespace
 
 NEInstanceNormalizationLayerKernel::NEInstanceNormalizationLayerKernel()
-    : _func(nullptr), _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12)
+    : _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12)
 {
 }
 
@@ -227,28 +145,6 @@
 
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), _gamma, _beta, _epsilon));
 
-    if(_input->info()->data_type() == DataType::F32)
-    {
-        _func = &instance_normalization_nchw<float>;
-    }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-    else if(_input->info()->data_type() == DataType::F16)
-    {
-        if(_use_mixed_precision)
-        {
-            _func = &instance_normalization_nchw<float16_t, float>;
-        }
-        else
-        {
-            _func = &instance_normalization_nchw<float16_t>;
-        }
-    }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-    else
-    {
-        ARM_COMPUTE_ERROR("Unsupported data type");
-    }
-
     // Configure kernel window
     auto win_config = validate_and_configure_window(_input->info(), _output->info());
     ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
@@ -268,6 +164,10 @@
     ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
-    (*_func)(_input, _output, _gamma, _beta, _epsilon, window);
+
+    const auto *uk = get_implementation(InstanceNormSelectorData{ _input->info()->data_type() });
+    ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
+
+    uk->ukernel(_input, _output, _gamma, _beta, _epsilon, _use_mixed_precision, window);
 }
 } // namespace arm_compute
diff --git a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
index 96c0119..f166ce2 100644
--- a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
+++ b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -84,13 +84,12 @@
      */
     using NormalizationFunction = void(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
 
-    NormalizationFunction *_func;
-    ITensor               *_input;
-    ITensor               *_output;
-    float                  _gamma;
-    float                  _beta;
-    float                  _epsilon;
-    bool                   _use_mixed_precision{ true };
+    ITensor *_input;
+    ITensor *_output;
+    float    _gamma;
+    float    _beta;
+    float    _epsilon;
+    bool     _use_mixed_precision{ true };
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_NEINSTANCENORMALIZATIONLAYERKERNEL_H */
diff --git a/src/cpu/kernels/instancenorm/generic/neon/fp16.cpp b/src/cpu/kernels/instancenorm/generic/neon/fp16.cpp
new file mode 100644
index 0000000..e9fcc84
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/fp16.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+#include "src/cpu/kernels/instancenorm/generic/neon/impl.h"
+namespace arm_compute
+{
+namespace cpu
+{
+void neon_fp16_instancenorm(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)
+{
+    if(use_mixed_precision)
+    {
+        return instance_normalization_nchw<float16_t, float>(input, output, gamma, beta, epsilon, window);
+    }
+    else
+    {
+        return instance_normalization_nchw<float16_t>(input, output, gamma, beta, epsilon, window);
+    }
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/instancenorm/generic/neon/fp32.cpp b/src/cpu/kernels/instancenorm/generic/neon/fp32.cpp
new file mode 100644
index 0000000..061dd95
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/fp32.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/instancenorm/generic/neon/impl.h"
+namespace arm_compute
+{
+namespace cpu
+{
+void neon_fp32_instancenorm(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(use_mixed_precision);
+    return instance_normalization_nchw<float>(input, output, gamma, beta, epsilon, window);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/instancenorm/generic/neon/impl.cpp b/src/cpu/kernels/instancenorm/generic/neon/impl.cpp
new file mode 100644
index 0000000..e35cf97
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/impl.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright (c) 2019-2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/instancenorm/generic/neon/impl.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+
+namespace arm_compute
+{
+class ITensor;
+class Window;
+namespace cpu
+{
+template <typename InputType, typename AccType>
+void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
+{
+    result        = wrapper::vadd(result, inputs);
+    result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
+{
+    vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
+    vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
+}
+template <>
+inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
+{
+    const auto  input_low   = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
+    const auto  input_high  = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
+    const auto  result_low  = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
+    const auto  result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
+    float16x8_t result      = wrapper::vcombine(result_low, result_high);
+
+    return result;
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename InputType, typename AccType>
+InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
+{
+    return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
+}
+
+template <typename T, typename AccType>
+void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
+{
+    /** SIMD vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+    // Clear X/Y dimensions on execution window as we handle the planes manually
+    Window win = window;
+    win.set(Window::DimX, Window::Dimension(0, 1, 1));
+    win.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+    constexpr int      window_step_x  = 16 / sizeof(T);
+    const unsigned int elements_plane = input->info()->dimension(0) * output->info()->dimension(1);
+
+    Iterator input_it(input, win);
+    execute_window_loop(win, [&](const Coordinates & id)
+    {
+        Window win_plane = window;
+        win_plane.set(Window::DimX, Window::Dimension(0, 1, 1));
+        win_plane.set(Window::DimZ, Window::Dimension(id[2], id[2] + 1, 1));
+        win_plane.set(3, Window::Dimension(id[3], id[3] + 1, 1));
+
+        Iterator input_plane_it(input, win_plane);
+        Iterator output_plane_it(output, win_plane);
+
+        auto sum_h_w         = static_cast<AccType>(0.f);
+        auto sum_squares_h_w = static_cast<AccType>(0.f);
+
+        execute_window_loop(win_plane, [&](const Coordinates &)
+        {
+            const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
+
+            auto vec_sum_h_w         = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
+            auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
+
+            // Compute S elements per iteration
+            int x = window.x().start();
+            for(; x <= (window.x().end() - window_step_x); x += window_step_x)
+            {
+                auto vec_input_val = wrapper::vloadq(input_ptr + x);
+                vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
+            }
+
+            auto vec2_sum_h_w         = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
+            auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));
+
+            vec2_sum_h_w         = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
+            vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
+
+            sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0);
+            sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);
+
+            // Compute left-over elements
+            for(; x < window.x().end(); ++x)
+            {
+                const auto value = static_cast<AccType>(*(input_ptr + x));
+                sum_h_w += value;
+                sum_squares_h_w += value * value;
+            }
+        },
+        input_plane_it, output_plane_it);
+
+        const auto mean_h_w = sum_h_w / elements_plane;
+        const auto var_h_w  = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
+
+        const auto multip_h_w     = gamma / std::sqrt(var_h_w + epsilon);
+        const auto vec_mean_h_w   = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
+        const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
+        const auto vec_beta       = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
+
+        execute_window_loop(win_plane, [&](const Coordinates &)
+        {
+            auto input_ptr  = reinterpret_cast<T *>(input_plane_it.ptr());
+            auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
+
+            // Compute S elements per iteration
+            int x = window.x().start();
+            for(; x <= (window.x().end() - window_step_x); x += window_step_x)
+            {
+                const auto vec_val        = wrapper::vloadq(input_ptr + x);
+                const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
+                wrapper::vstore(output_ptr + x, normalized_vec);
+            }
+
+            // Compute left-over elements
+            for(; x < window.x().end(); ++x)
+            {
+                const auto val    = static_cast<AccType>(*(input_ptr + x));
+                *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
+            }
+        },
+        input_plane_it, output_plane_it);
+    },
+    input_it);
+}
+
+template void instance_normalization_nchw<float>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+template void instance_normalization_nchw<float16_t, float>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+template void instance_normalization_nchw<float16_t>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/instancenorm/generic/neon/impl.h b/src/cpu/kernels/instancenorm/generic/neon/impl.h
new file mode 100644
index 0000000..fa4b4b6
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/impl.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_NEON_KERNELS_INSTANCENORM_IMPL_H
+#define SRC_CORE_NEON_KERNELS_INSTANCENORM_IMPL_H
+#include "arm_compute/core/Helpers.h"
+namespace arm_compute
+{
+namespace cpu
+{
+template <typename T, typename AccType = T>
+void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+
+template <typename InputType, typename AccType = InputType>
+void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs);
+
+template <typename InputType, typename AccType = InputType>
+InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta);
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+template <>
+inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs);
+
+template <>
+inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta);
+#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+} // namespace cpu
+} // namespace arm_compute
+#endif // SRC_CORE_NEON_KERNELS_INSTANCENORM_IMPL_H
diff --git a/src/cpu/kernels/instancenorm/list.h b/src/cpu/kernels/instancenorm/list.h
new file mode 100644
index 0000000..54f1d32
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/list.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_NEON_KERNELS_INSTANCENORM_LIST_H
+#define SRC_CORE_NEON_KERNELS_INSTANCENORM_LIST_H
+namespace arm_compute
+{
+namespace cpu
+{
+#define DECLARE_INSTANCENORM_KERNEL(func_name) \
+    void func_name(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)
+DECLARE_INSTANCENORM_KERNEL(neon_fp32_instancenorm);
+DECLARE_INSTANCENORM_KERNEL(neon_fp16_instancenorm);
+#undef DECLARE_INSTANCENORM_KERNEL
+} // namespace cpu
+} // namespace arm_compute
+#endif //SRC_CORE_NEON_KERNELS_INSTANCENORM_LIST_H