COMPMID-814: NEScale NHWC support

Change-Id: Ibf5c624a5c5482faa42eb02bc8abe9ae0d65b0d1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130608
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/Helpers.cpp b/src/core/Helpers.cpp
index c39922b..e336331 100644
--- a/src/core/Helpers.cpp
+++ b/src/core/Helpers.cpp
@@ -177,21 +177,25 @@
 ValidRegion arm_compute::calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape,
                                                       InterpolationPolicy interpolate_policy, SamplingPolicy sampling_policy, bool border_undefined)
 {
-    const float scale_x        = static_cast<float>(dst_shape[0]) / src_info.tensor_shape()[0];
-    const float scale_y        = static_cast<float>(dst_shape[1]) / src_info.tensor_shape()[1];
+    const DataLayout data_layout = src_info.data_layout();
+    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+    const float scale_x        = static_cast<float>(dst_shape[idx_width]) / src_info.tensor_shape()[idx_width];
+    const float scale_y        = static_cast<float>(dst_shape[idx_height]) / src_info.tensor_shape()[idx_height];
     const float sampling_point = (sampling_policy == SamplingPolicy::CENTER) ? 0.5f : 0.0f;
 
     // Get input's valid region start and end points
-    const int valid_start_in_x = src_info.valid_region().anchor[0];
-    const int valid_start_in_y = src_info.valid_region().anchor[1];
-    const int valid_end_in_x   = src_info.valid_region().anchor[0] + src_info.valid_region().shape[0];
-    const int valid_end_in_y   = src_info.valid_region().anchor[1] + src_info.valid_region().shape[1];
+    const int valid_start_in_x = src_info.valid_region().anchor[idx_width];
+    const int valid_start_in_y = src_info.valid_region().anchor[idx_height];
+    const int valid_end_in_x   = src_info.valid_region().anchor[idx_width] + src_info.valid_region().shape[idx_width];
+    const int valid_end_in_y   = src_info.valid_region().anchor[idx_height] + src_info.valid_region().shape[idx_height];
 
     // Initialize output's valid region start and end points
     auto valid_start_out_x = static_cast<int>(valid_start_in_x * scale_x);
     auto valid_start_out_y = static_cast<int>(valid_start_in_y * scale_y);
-    auto valid_end_out_x   = std::min<int>(std::ceil(valid_end_in_x * scale_x), dst_shape[0]);
-    auto valid_end_out_y   = std::min<int>(std::ceil(valid_end_in_y * scale_y), dst_shape[1]);
+    auto valid_end_out_x   = std::min<int>(std::ceil(valid_end_in_x * scale_x), dst_shape[idx_width]);
+    auto valid_end_out_y   = std::min<int>(std::ceil(valid_end_in_y * scale_y), dst_shape[idx_height]);
 
     // Handle valid points in case of the bi-linear interpolation
     if(border_undefined)
@@ -237,11 +241,11 @@
     // Setup output valid region
     ValidRegion valid_region{ Coordinates(), dst_shape, src_info.tensor_shape().num_dimensions() };
 
-    valid_region.anchor.set(0, std::max(0, valid_start_out_x));
-    valid_region.anchor.set(1, std::max(0, valid_start_out_y));
+    valid_region.anchor.set(idx_width, std::max(0, valid_start_out_x));
+    valid_region.anchor.set(idx_height, std::max(0, valid_start_out_y));
 
-    valid_region.shape.set(0, std::min<size_t>(valid_end_out_x - valid_start_out_x, dst_shape[0]));
-    valid_region.shape.set(1, std::min<size_t>(valid_end_out_y - valid_start_out_y, dst_shape[1]));
+    valid_region.shape.set(idx_width, std::min<size_t>(valid_end_out_x - valid_start_out_x, dst_shape[idx_width]));
+    valid_region.shape.set(idx_height, std::min<size_t>(valid_end_out_y - valid_start_out_y, dst_shape[idx_height]));
 
     return valid_region;
 }
\ No newline at end of file
diff --git a/src/core/NEON/kernels/NEScaleKernel.cpp b/src/core/NEON/kernels/NEScaleKernel.cpp
index 852ec3e..311c807 100644
--- a/src/core/NEON/kernels/NEScaleKernel.cpp
+++ b/src/core/NEON/kernels/NEScaleKernel.cpp
@@ -28,28 +28,174 @@
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/Utility.h"
 
 #include <arm_neon.h>
 #include <cstddef>
 #include <cstdint>
 
-using namespace arm_compute;
+namespace arm_compute
+{
+namespace
+{
+Window configure_nchw(const ITensor *input, const ITensor *dx, const ITensor *dy, const ITensor *offsets, ITensor *output,
+                      InterpolationPolicy policy, bool border_undefined, SamplingPolicy sampling_policy, BorderSize border_size)
+{
+    // NCHW kernels handle 16 consecutive elements along X per iteration
+    constexpr unsigned int step = 16;
+    // offsets/dx/dy are optional: an absent tensor maps to a null tensor info
+    const auto info_or_null = [](const ITensor *tensor) { return (tensor != nullptr) ? tensor->info() : nullptr; };
+    const ValidRegion &in_valid = input->info()->valid_region();
+
+    // Reads can occur within the valid region of the input, widened by the border
+    AccessWindowStatic input_access(input->info(),
+                                    in_valid.anchor[0] - border_size.left, in_valid.anchor[1] - border_size.top,
+                                    in_valid.anchor[0] + in_valid.shape[0] + border_size.right,
+                                    in_valid.anchor[1] + in_valid.shape[1] + border_size.bottom);
+    AccessWindowHorizontal offsets_access(info_or_null(offsets), 0, step);
+    AccessWindowHorizontal dx_access(info_or_null(dx), 0, step);
+    AccessWindowHorizontal dy_access(info_or_null(dy), 0, step);
+    AccessWindowHorizontal output_access(output->info(), 0, step);
+
+    // Configure the kernel window over the output and update padding for all accesses
+    Window win = calculate_max_window(*output->info(), Steps(step));
+    update_window_and_padding(win, input_access, offsets_access, dx_access, dy_access, output_access);
+
+    // Output valid region: the input valid region mapped through the scale ratio
+    const ValidRegion out_valid = calculate_valid_region_scale(*input->info(), output->info()->tensor_shape(), policy, sampling_policy, border_undefined);
+    output_access.set_valid_region(win, out_valid);
+    return win;
+}
+Window configure_nhwc(const ITensor *input, ITensor *output,
+                      InterpolationPolicy policy, bool border_undefined, SamplingPolicy sampling_policy, BorderSize border_size)
+{
+    // Nearest neighbour moves 16 bytes of channels per iteration; bilinear handles one element at a time
+    const unsigned int step     = (policy == InterpolationPolicy::NEAREST_NEIGHBOR) ? 16 / input->info()->element_size() : 1;
+    const TensorShape &in_shape = input->info()->tensor_shape();
+
+    // Channels (dimension 0) are read in whole steps; reads may start border_size.top elements before dimension 1
+    AccessWindowStatic     input_access(input->info(), 0, -border_size.top, ceil_to_multiple(in_shape[0], step), in_shape[1]);
+    AccessWindowHorizontal output_access(output->info(), 0, step);
+
+    // Configure kernel window
+    Window win = calculate_max_window(*output->info(), Steps(step));
+    update_window_and_padding(win, input_access, output_access);
+
+    // The valid region is written straight to the output info
+    output->info()->set_valid_region(calculate_valid_region_scale(*input->info(), output->info()->tensor_shape(), policy, sampling_policy, border_undefined));
+    return win;
+}
+
+template <typename T>
+inline void scale_nearest_nhwc_core(const ITensor *input, const ITensor *offsets, ITensor *output,
+                                    float hr, Window window, const Window &win_in, size_t stride_w, size_t stride_h, size_t stride_c)
+{
+    // Offsets appear to be stored as x * sizeof(T) bytes; this factor turns them into x * stride_w bytes
+    const size_t x_scale = stride_w / sizeof(T);
+    Iterator     src(input, win_in);
+    Iterator     dst(output, window);
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        const int32_t x_bytes = *reinterpret_cast<const int32_t *>(offsets->ptr_to_element(Coordinates(id.y(), id.z())));
+        const int     src_row = (id.z() + 0.5f) * hr; // nearest source row (pixel-centre sampling)
+        const int     row_off = src_row * stride_h + id.x() * stride_c;
+        const T      *src_ptr = reinterpret_cast<const T *>(src.ptr() + x_bytes * x_scale + row_off);
+        // Copy one full vector of channels from the selected source pixel
+        wrapper::vstore(reinterpret_cast<T *>(dst.ptr()), wrapper::vloadq(src_ptr));
+    },
+    src, dst);
+}
+
+template <typename T>
+inline void scale_bilinear_nhwc_core(const ITensor *input, const ITensor *offsets, const ITensor *dx, const ITensor *dy, ITensor *output,
+                                     float hr, Window window, const Window &win_in, size_t stride_w, size_t stride_h, size_t stride_c, BorderMode border_mode)
+{
+    // Bilinear NHWC scaling, one output element per iteration. offsets/dx/dy are indexed by (id.y(), id.z()) and hold the
+    // left source column (assumed stored as x * sizeof(T) bytes) and the horizontal/vertical interpolation weights.
+    Iterator in(input, win_in);
+    Iterator out(output, window);
+
+    // NHWC: dimension 1 is the width, dimension 2 the height
+    const int input_width  = static_cast<int>(input->info()->dimension(1));
+    const int input_height = static_cast<int>(input->info()->dimension(2));
+    // Element one stride_w before the first element; only dereferenced in CONSTANT mode, where a 1-element top border is requested
+    const T *border_area = reinterpret_cast<T *>(input->buffer() + input->info()->offset_first_element_in_bytes() - stride_w);
+
+    auto is_valid = [](int x, int low_x, int high_x, int y, int low_y, int high_y)
+    {
+        return !(x < low_x || x > high_x || y < low_y || y > high_y);
+    };
+
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        // Divide in signed arithmetic: the stored offset is negative at the left edge, and dividing it by the
+        // unsigned sizeof(T) would first convert it to a huge unsigned value
+        const int  offset   = *reinterpret_cast<const int32_t *>(offsets->ptr_to_element(Coordinates(id.y(), id.z()))) / static_cast<int>(sizeof(T));
+        const auto dx_scale = *reinterpret_cast<const float *>(dx->ptr_to_element(Coordinates(id.y(), id.z())));
+        const auto dy_scale = *reinterpret_cast<const float *>(dy->ptr_to_element(Coordinates(id.y(), id.z())));
+        const int  in_yi    = std::floor((id.z() + 0.5f) * hr - 0.5f);
+        // Reads the source sample at column x, row y of the current channel; (x, y) must lie inside the input
+        const auto sample = [&](int x, int y) -> T
+        {
+            return *reinterpret_cast<const T *>(in.ptr() + x * stride_w + y * stride_h + id.x() * stride_c);
+        };
+
+        T a00 = 0, a01 = 0, a10 = 0, a11 = 0;
+        if(border_mode == BorderMode::CONSTANT)
+        {
+            // Samples falling outside the input take the constant border value
+            const auto fetch = [&](int x, int y) -> T
+            {
+                return is_valid(x, 0, input_width - 1, y, 0, input_height - 1) ? sample(x, y) : *border_area;
+            };
+            a00 = fetch(offset, in_yi);
+            a01 = fetch(offset + 1, in_yi);
+            a10 = fetch(offset, in_yi + 1);
+            a11 = fetch(offset + 1, in_yi + 1);
+        }
+        else
+        {
+            // REPLICATE clamps to the nearest edge sample. UNDEFINED requests no padding for NHWC, so it is clamped
+            // too instead of reading outside the tensor (border-dependent outputs are unspecified in that mode)
+            const int x0 = utility::clamp<int>(offset, 0, input_width - 1);
+            const int x1 = utility::clamp<int>(offset + 1, 0, input_width - 1);
+            const int y0 = utility::clamp<int>(in_yi, 0, input_height - 1);
+            const int y1 = utility::clamp<int>(in_yi + 1, 0, input_height - 1);
+            a00 = sample(x0, y0);
+            a01 = sample(x1, y0);
+            a10 = sample(x0, y1);
+            a11 = sample(x1, y1);
+        }
+
+        // Perform interpolation and store the result
+        const float dx1 = 1.0f - dx_scale;
+        const float dy1 = 1.0f - dy_scale;
+        const float w1  = dx1 * dy1;
+        const float w2  = dx_scale * dy1;
+        const float w3  = dx1 * dy_scale;
+        const float w4  = dx_scale * dy_scale;
+        *reinterpret_cast<T *>(out.ptr()) = static_cast<T>(a00 * w1 + a01 * w2 + a10 * w3 + a11 * w4);
+    },
+    in, out);
+}
+} // namespace
 
 NEScaleKernel::NEScaleKernel()
-    : _func(nullptr), _offsets(nullptr), _dx(nullptr), _dy(nullptr), _input(nullptr), _output(nullptr)
+    : _func(nullptr), _offsets(nullptr), _dx(nullptr), _dy(nullptr), _input(nullptr), _output(nullptr), _policy(), _border_size(1), _border_mode() // border starts at 1; configure() recomputes it per layout
 {
 }
 
 BorderSize NEScaleKernel::border_size() const
 {
-    return BorderSize(1);
+    return _border_size; // 1 for NCHW; for NHWC top-only (CONSTANT + BILINEAR) or 0, as set in configure()
 }
 
-void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITensor *dy, const ITensor *offsets, ITensor *output, InterpolationPolicy policy, bool border_undefined,
-                              SamplingPolicy sampling_policy)
+void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITensor *dy, const ITensor *offsets,
+                              ITensor *output, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
@@ -70,35 +216,45 @@
         ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dy, 1, DataType::F32);
     }
 
-    ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) == 0);
-    ARM_COMPUTE_ERROR_ON(output->info()->dimension(1) == 0);
+    // Get data layout and width/height indices
+    const DataLayout data_layout = input->info()->data_layout();
+    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
 
-    for(size_t i = 2; i < Coordinates::num_max_dimensions; ++i)
+    ARM_COMPUTE_ERROR_ON(output->info()->dimension(idx_width) == 0);
+    ARM_COMPUTE_ERROR_ON(output->info()->dimension(idx_height) == 0);
+
+    _input       = input;
+    _output      = output;
+    _offsets     = offsets;
+    _dx          = dx;
+    _dy          = dy;
+    _policy      = policy;
+    _border_size = BorderSize(1);
+    _border_mode = border_mode;
+
+    // Compute the ratio between source width/height and destination width/height
+    const auto wr = static_cast<float>(input->info()->dimension(idx_width)) / static_cast<float>(output->info()->dimension(idx_width));
+    const auto hr = static_cast<float>(input->info()->dimension(idx_height)) / static_cast<float>(output->info()->dimension(idx_height));
+
+    // Add constant border only on top in case of NHWC layout
+    if(data_layout == DataLayout::NHWC)
     {
-        ARM_COMPUTE_ERROR_ON(input->info()->dimension(i) != output->info()->dimension(i));
+        _border_size = (border_mode == BorderMode::CONSTANT && policy == InterpolationPolicy::BILINEAR) ? BorderSize(1, 0, 0, 0) : BorderSize(0);
     }
 
-    _input   = input;
-    _output  = output;
-    _offsets = offsets;
-    _dx      = dx;
-    _dy      = dy;
-
-    /* Compute the ratio between source width/height and destination width/height */
-    const auto wr = static_cast<float>(input->info()->dimension(0)) / static_cast<float>(output->info()->dimension(0));
-    const auto hr = static_cast<float>(input->info()->dimension(1)) / static_cast<float>(output->info()->dimension(1));
-
-    /* Area interpolation behaves as Nearest Neighbour in case of up-sampling */
+    // Area interpolation behaves as Nearest Neighbour in case of up-sampling
     if(policy == InterpolationPolicy::AREA && wr <= 1.f && hr <= 1.f)
     {
         policy = InterpolationPolicy::NEAREST_NEIGHBOR;
     }
 
+    // Select interpolation function
     switch(policy)
     {
         case InterpolationPolicy::NEAREST_NEIGHBOR:
         {
-            _func = &NEScaleKernel::scale_nearest;
+            _func = (data_layout == DataLayout::NCHW) ? &NEScaleKernel::scale_nearest_nchw : &NEScaleKernel::scale_nhwc;
             break;
         }
         case InterpolationPolicy::BILINEAR:
@@ -106,51 +262,37 @@
             ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(_dx, 1, DataType::F32);
             ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(_dy, 1, DataType::F32);
 
-            _func = &NEScaleKernel::scale_bilinear;
+            _func = (data_layout == DataLayout::NCHW) ? &NEScaleKernel::scale_bilinear_nchw : &NEScaleKernel::scale_nhwc;
             break;
         }
         case InterpolationPolicy::AREA:
         {
-            _func = &NEScaleKernel::scale_area;
+            ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NCHW);
+
+            _func = &NEScaleKernel::scale_area_nchw;
             break;
         }
         default:
             ARM_COMPUTE_ERROR("Unsupported interpolation mode");
     }
 
-    constexpr unsigned int num_elems_processed_per_iteration = 16;
-
-    // Configure kernel window
-    Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration));
-
-    const ValidRegion &input_valid_region = input->info()->valid_region();
-
-    // Reads can occur within the valid region of the input
-    AccessWindowStatic input_access(input->info(),
-                                    input_valid_region.anchor[0] - border_size().left, input_valid_region.anchor[1] - border_size().top,
-                                    input_valid_region.anchor[0] + input_valid_region.shape[0] + border_size().right,
-                                    input_valid_region.anchor[1] + input_valid_region.shape[1] + border_size().bottom);
-    AccessWindowHorizontal offsets_access(offsets == nullptr ? nullptr : offsets->info(), 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal dx_access(dx == nullptr ? nullptr : dx->info(), 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal dy_access(dy == nullptr ? nullptr : dy->info(), 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
-
-    update_window_and_padding(win,
-                              input_access,
-                              offsets_access,
-                              dx_access,
-                              dy_access,
-                              output_access);
-
-    output_access.set_valid_region(win, calculate_valid_region_scale(*(input->info()),
-                                                                     output->info()->tensor_shape(),
-                                                                     policy,
-                                                                     sampling_policy,
-                                                                     border_undefined));
+    // Configure window
+    Window win{};
+    switch(data_layout)
+    {
+        case DataLayout::NCHW:
+            win = configure_nchw(input, dx, dy, offsets, output, policy, border_mode == BorderMode::UNDEFINED, sampling_policy, border_size());
+            break;
+        case DataLayout::NHWC:
+            win = configure_nhwc(input, output, policy, border_mode == BorderMode::UNDEFINED, sampling_policy, border_size());
+            break;
+        default:
+            ARM_COMPUTE_ERROR("Unsupported data layout");
+    }
     INEKernel::configure(win);
 }
 
-void NEScaleKernel::scale_nearest(const Window &window)
+void NEScaleKernel::scale_nearest_nchw(const Window &window)
 {
     const size_t input_stride = _input->info()->strides_in_bytes()[1];
 
@@ -163,15 +305,16 @@
     win_in.set(Window::DimX, Window::Dimension(0, 0, 0));
     win_in.set(Window::DimY, Window::Dimension(0, 0, 0));
 
+    // Set offsets window
     Window win_off;
     win_off.set(Window::DimX, window[Window::DimX]);
     win_off.set(Window::DimY, window[Window::DimY]);
-
     for(size_t d = Window::DimZ; d < _offsets->info()->num_dimensions(); ++d)
     {
         win_off.set(d, Window::Dimension(0, 0, 0));
     }
 
+    // Create iterators
     Iterator in(_input, win_in);
     Iterator out(_output, window);
     Iterator offsets(_offsets, win_off);
@@ -304,7 +447,7 @@
     }
 }
 
-void NEScaleKernel::scale_bilinear(const Window &window)
+void NEScaleKernel::scale_bilinear_nchw(const Window &window)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(_input, 1, DataType::U8, DataType::S16, DataType::F32);
 
@@ -469,15 +612,16 @@
     }
 }
 
-void NEScaleKernel::scale_area(const Window &window)
+void NEScaleKernel::scale_area_nchw(const Window &window)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(_input, 1, DataType::U8);
 
-    // Don't increment in X and Y direction for the input tensor
+    // Don't increment in width/height/channels for the input tensor
     // A pointer to the start of this plane is needed as base for the precomputed offsets
     Window win_in(window);
     win_in.set(Window::DimX, Window::Dimension(0, 0, 0));
     win_in.set(Window::DimY, Window::Dimension(0, 0, 0));
+    win_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
 
     Iterator in(_input, win_in);
     Iterator out(_output, window);
@@ -517,6 +661,77 @@
     in, out);
 }
 
+namespace
+{
+// Runs the NHWC core matching 'policy' for element type T.
+// Nearest neighbour ignores dx/dy and border_mode; bilinear uses all of them.
+// stride_w/stride_h/stride_c are the input byte strides of the width, height and channel dimensions.
+template <typename T>
+void scale_nhwc_typed(InterpolationPolicy policy, BorderMode border_mode, const ITensor *input, const ITensor *offsets,
+                      const ITensor *dx, const ITensor *dy, ITensor *output, float hr, const Window &window, const Window &win_in,
+                      size_t stride_w, size_t stride_h, size_t stride_c)
+{
+    if(policy == InterpolationPolicy::NEAREST_NEIGHBOR)
+    {
+        scale_nearest_nhwc_core<T>(input, offsets, output, hr, window, win_in, stride_w, stride_h, stride_c);
+    }
+    else
+    {
+        scale_bilinear_nhwc_core<T>(input, offsets, dx, dy, output, hr, window, win_in, stride_w, stride_h, stride_c, border_mode);
+    }
+}
+} // namespace
+
+// Scales an NHWC tensor: resolves the input strides and height ratio once,
+// then runs the typed core for the input data type.
+void NEScaleKernel::scale_nhwc(const Window &window)
+{
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(_input, 1, DataType::U8, DataType::S16, DataType::F32);
+
+    // Locate the channel/width/height dimensions for the input's layout
+    const ITensorInfo &in_info    = *_input->info();
+    const DataLayout   layout     = in_info.data_layout();
+    const int          idx_height = get_data_layout_dimension_index(layout, DataLayoutDimension::HEIGHT);
+    const size_t       stride_c   = in_info.strides_in_bytes()[get_data_layout_dimension_index(layout, DataLayoutDimension::CHANNEL)];
+    const size_t       stride_w   = in_info.strides_in_bytes()[get_data_layout_dimension_index(layout, DataLayoutDimension::WIDTH)];
+    const size_t       stride_h   = in_info.strides_in_bytes()[idx_height];
+
+    // Only the vertical ratio is needed: horizontal source positions come from the precomputed offsets
+    const auto hr = static_cast<float>(in_info.dimension(idx_height)) / static_cast<float>(_output->info()->dimension(idx_height));
+
+    // The input iterator must not advance along width/height/channels (dimensions 0-2):
+    // the cores address them explicitly from the iterator's base pointer using the strides above
+    Window win_in(window);
+    win_in.set(Window::DimX, Window::Dimension(0, 0, 0));
+    win_in.set(Window::DimY, Window::Dimension(0, 0, 0));
+    win_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+    switch(in_info.data_type())
+    {
+        case DataType::U8:
+        {
+            scale_nhwc_typed<uint8_t>(_policy, _border_mode, _input, _offsets, _dx, _dy, _output, hr,
+                                      window, win_in, stride_w, stride_h, stride_c);
+            break;
+        }
+        case DataType::S16:
+        {
+            scale_nhwc_typed<int16_t>(_policy, _border_mode, _input, _offsets, _dx, _dy, _output, hr,
+                                      window, win_in, stride_w, stride_h, stride_c);
+            break;
+        }
+        case DataType::F32:
+        {
+            scale_nhwc_typed<float>(_policy, _border_mode, _input, _offsets, _dx, _dy, _output, hr,
+                                    window, win_in, stride_w, stride_h, stride_c);
+            break;
+        }
+        default:
+            ARM_COMPUTE_ERROR("Not supported");
+            break;
+    }
+}
+
 void NEScaleKernel::run(const Window &window, const ThreadInfo &info)
 {
     ARM_COMPUTE_UNUSED(info);
@@ -526,3 +741,4 @@
 
     (this->*_func)(window);
 }
+} // namespace arm_compute