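COMPMID-814: NEScale NHWC support

Tests: add a DataLayout parameter to the Scale validation fixture; the
NEON Scale tests now cover NCHW and NHWC, while the CL and GLES_COMPUTE
Scale tests keep NCHW only.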

Change-Id: Ibf5c624a5c5482faa42eb02bc8abe9ae0d65b0d1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130608
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/tests/validation/CL/Scale.cpp b/tests/validation/CL/Scale.cpp
index cc4fdb0..3d8750a 100644
--- a/tests/validation/CL/Scale.cpp
+++ b/tests/validation/CL/Scale.cpp
@@ -118,7 +118,9 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+                                                                                                                     DataType::F32)),
+                                                                                                                     framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                              framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                      datasets::BorderModes()),
                                                                                              datasets::SamplingPolicies()))
@@ -130,7 +132,9 @@
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+                                                                                                                 DataType::F32)),
+                                                                                                                 framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                                  framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                          datasets::BorderModes()),
                                                                                                  datasets::SamplingPolicies()))
@@ -144,7 +148,9 @@
 }
 TEST_SUITE_END()
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+                                                                                                                    DataType::F16)),
+                                                                                                                    framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                             framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                     datasets::BorderModes()),
                                                                                             datasets::SamplingPolicies()))
@@ -156,8 +162,9 @@
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
                                                                                                                         DataType::F16)),
+                                                                                                                        framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                                 framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                         datasets::BorderModes()),
                                                                                                 datasets::SamplingPolicies()))
@@ -174,7 +181,9 @@
 
 TEST_SUITE(Integer)
 TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::U8)),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+                                                                                                                       DataType::U8)),
+                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                                framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                        datasets::BorderModes()),
                                                                                                datasets::SamplingPolicies()))
@@ -186,7 +195,9 @@
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_u8);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::U8)),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+                                                                                                                   DataType::U8)),
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                                    framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                            datasets::BorderModes()),
                                                                                                    datasets::SamplingPolicies()))
@@ -200,7 +211,9 @@
 }
 TEST_SUITE_END()
 TEST_SUITE(S16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+                                                                                                                       DataType::S16)),
+                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                                framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                        datasets::BorderModes()),
                                                                                                datasets::SamplingPolicies()))
@@ -212,8 +225,9 @@
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_s16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunLarge, CLScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
                                                                                                                    DataType::S16)),
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                                    framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                            datasets::BorderModes()),
                                                                                                    datasets::SamplingPolicies()))
diff --git a/tests/validation/GLES_COMPUTE/Scale.cpp b/tests/validation/GLES_COMPUTE/Scale.cpp
index 9f670e4..4bfa08f 100644
--- a/tests/validation/GLES_COMPUTE/Scale.cpp
+++ b/tests/validation/GLES_COMPUTE/Scale.cpp
@@ -108,7 +108,9 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, GCScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, GCScaleFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+                                                                                                                    DataType::F16)),
+                                                                                                                    framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                             framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR })),
                                                                                                     datasets::BorderModes()),
                                                                                             datasets::SamplingPolicies()))
@@ -120,8 +122,9 @@
     // Validate output
     validate(GCAccessor(_target), _reference, valid_region, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, GCScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunLarge, GCScaleFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
                                                                                                                         DataType::F16)),
+                                                                                                                        framework::dataset::make("DataLayout", { DataLayout::NCHW })),
                                                                                                                 framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR })),
                                                                                                         datasets::BorderModes()),
                                                                                                 datasets::SamplingPolicies()))
diff --git a/tests/validation/NEON/Scale.cpp b/tests/validation/NEON/Scale.cpp
index 5f76a0c..b21affd 100644
--- a/tests/validation/NEON/Scale.cpp
+++ b/tests/validation/NEON/Scale.cpp
@@ -55,6 +55,13 @@
     DataType::F32,
 });
 
+/** Scale data layouts */
+const auto ScaleDataLayouts = framework::dataset::make("DataLayout",
+{
+    DataLayout::NCHW,
+    DataLayout::NHWC,
+});
+
 /** Tolerance */
 constexpr AbsoluteTolerance<uint8_t> tolerance_u8(1);
 constexpr AbsoluteTolerance<int16_t> tolerance_s16(1);
@@ -67,29 +74,42 @@
 TEST_SUITE(NEON)
 TEST_SUITE(Scale)
 
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(concat(datasets::SmallShapes(), datasets::LargeShapes()), ScaleDataTypes),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(concat(datasets::SmallShapes(), datasets::LargeShapes()), ScaleDataTypes), ScaleDataLayouts),
                                                                                    framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                            datasets::BorderModes()),
                                                                    framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })),
-               shape, data_type, policy, border_mode, sampling_policy)
+               shape, data_type, data_layout, policy, border_mode, sampling_policy)
 {
     std::mt19937                          generator(library->seed());
     std::uniform_real_distribution<float> distribution_float(0.25, 2);
     const float                           scale_x               = distribution_float(generator);
     const float                           scale_y               = distribution_float(generator);
     uint8_t                               constant_border_value = 0;
+    TensorShape                           src_shape             = shape;
     if(border_mode == BorderMode::CONSTANT)
     {
         std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
         constant_border_value = distribution_u8(generator);
     }
 
+    // Get width/height indices depending on layout
+    const int idx_width  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+    // Change shape in case of NHWC.
+    if(data_layout == DataLayout::NHWC)
+    {
+        permute(src_shape, PermutationVector(2U, 0U, 1U));
+    }
+
+    // Calculate scaled shape
+    TensorShape shape_scaled(src_shape);
+    shape_scaled.set(idx_width, src_shape[idx_width] * scale_x);
+    shape_scaled.set(idx_height, src_shape[idx_height] * scale_y);
+
     // Create tensors
-    Tensor      src = create_tensor<Tensor>(shape, data_type);
-    TensorShape shape_scaled(shape);
-    shape_scaled.set(0, shape[0] * scale_x);
-    shape_scaled.set(1, shape[1] * scale_y);
-    Tensor dst = create_tensor<Tensor>(shape_scaled, data_type);
+    Tensor src = create_tensor<Tensor>(src_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
+    Tensor dst = create_tensor<Tensor>(shape_scaled, data_type, 1, 0, QuantizationInfo(), data_layout);
 
     ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
     ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -100,14 +120,26 @@
 
     // Validate valid region
     const ValidRegion dst_valid_region = calculate_valid_region_scale(*(src.info()), shape_scaled, policy, sampling_policy, (border_mode == BorderMode::UNDEFINED));
-
     validate(dst.info()->valid_region(), dst_valid_region);
 
     // Validate padding
-    PaddingCalculator calculator(shape_scaled.x(), 16);
+    int num_elements_processed_x = 16;
+    if(data_layout == DataLayout::NHWC)
+    {
+        num_elements_processed_x = (policy == InterpolationPolicy::BILINEAR) ? 1 : 16 / src.info()->element_size();
+    }
+    PaddingCalculator calculator(shape_scaled.x(), num_elements_processed_x);
     calculator.set_border_mode(border_mode);
 
-    const PaddingSize read_padding(1);
+    PaddingSize read_padding(1);
+    if(data_layout == DataLayout::NHWC)
+    {
+        read_padding = calculator.required_padding(PaddingCalculator::Option::EXCLUDE_BORDER);
+        if(border_mode == BorderMode::CONSTANT && policy == InterpolationPolicy::BILINEAR)
+        {
+            read_padding.top = 1;
+        }
+    }
     const PaddingSize write_padding = calculator.required_padding(PaddingCalculator::Option::EXCLUDE_BORDER);
     validate(src.info()->padding(), read_padding);
     validate(dst.info()->padding(), write_padding);
@@ -118,8 +150,9 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
                                                                                                                      DataType::F32)),
+                                                                                                                     framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                                                                                                              framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                      datasets::BorderModes()),
                                                                                              framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })))
@@ -131,8 +164,9 @@
     // Validate output
     validate(Accessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
                                                                                                                  DataType::F32)),
+                                                                                                                 framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                                                                                                                  framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                          datasets::BorderModes()),
                                                                                                  framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })))
@@ -149,8 +183,9 @@
 
 TEST_SUITE(Integer)
 TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
                                                                                                                        DataType::U8)),
+                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                                                                                                                framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                        datasets::BorderModes()),
                                                                                                framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })))
@@ -162,8 +197,9 @@
     // Validate output
     validate(Accessor(_target), _reference, valid_region, tolerance_u8);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
                                                                                                                    DataType::U8)),
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                                                                                                                    framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                            datasets::BorderModes()),
                                                                                                    framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })))
@@ -177,8 +213,9 @@
 }
 TEST_SUITE_END()
 TEST_SUITE(S16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, NEScaleFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
                                                                                                                        DataType::S16)),
+                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                                                                                                                framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                        datasets::BorderModes()),
                                                                                                framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })))
@@ -190,8 +227,9 @@
     // Validate output
     validate(Accessor(_target), _reference, valid_region, tolerance_s16, tolerance_num_s16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunLarge, NEScaleFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
                                                                                                                    DataType::S16)),
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                                                                                                                    framework::dataset::make("InterpolationPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR, InterpolationPolicy::BILINEAR })),
                                                                                                            datasets::BorderModes()),
                                                                                                    framework::dataset::make("SamplingPolicy", { SamplingPolicy::CENTER })))
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h
index 604bfb2..ec10231 100644
--- a/tests/validation/fixtures/ScaleFixture.h
+++ b/tests/validation/fixtures/ScaleFixture.h
@@ -44,7 +44,7 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
+    void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
     {
         constexpr float max_width  = 8192.0f;
         constexpr float max_height = 6384.0f;
@@ -60,13 +60,16 @@
         float                                 scale_x = distribution_float(generator);
         float                                 scale_y = distribution_float(generator);
 
-        scale_x = ((shape.x() * scale_x) > max_width) ? (max_width / shape.x()) : scale_x;
-        scale_y = ((shape.y() * scale_y) > max_height) ? (max_height / shape.y()) : scale_y;
+        const int idx_width  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+        const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+        scale_x = ((shape[idx_width] * scale_x) > max_width) ? (max_width / shape[idx_width]) : scale_x;
+        scale_y = ((shape[idx_height] * scale_y) > max_height) ? (max_height / shape[idx_height]) : scale_y;
 
         std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
         T                                      constant_border_value = static_cast<T>(distribution_u8(generator));
 
-        _target    = compute_target(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
+        _target    = compute_target(shape, data_layout, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
         _reference = compute_reference(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
     }
 
@@ -86,15 +89,25 @@
         }
     }
 
-    TensorType compute_target(const TensorShape &shape, const float scale_x, const float scale_y,
+    TensorType compute_target(TensorShape shape, DataLayout data_layout, const float scale_x, const float scale_y,
                               InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
     {
+        // Change shape in case of NHWC.
+        if(data_layout == DataLayout::NHWC)
+        {
+            permute(shape, PermutationVector(2U, 0U, 1U));
+        }
+
         // Create tensors
-        TensorType  src = create_tensor<TensorType>(shape, _data_type);
+        TensorType src = create_tensor<TensorType>(shape, _data_type, 1, 0, QuantizationInfo(), data_layout);
+
+        const int idx_width  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+        const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
         TensorShape shape_scaled(shape);
-        shape_scaled.set(0, shape[0] * scale_x);
-        shape_scaled.set(1, shape[1] * scale_y);
-        TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type);
+        shape_scaled.set(idx_width, shape[idx_width] * scale_x);
+        shape_scaled.set(idx_height, shape[idx_height] * scale_y);
+        TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, 0, QuantizationInfo(), data_layout);
 
         // Create and configure function
         FunctionType scale;
@@ -123,7 +136,7 @@
                                       InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
     {
         // Create reference
-        SimpleTensor<T> src{ shape, _data_type };
+        SimpleTensor<T> src{ shape, _data_type, 1, 0, QuantizationInfo() };
 
         // Fill reference
         fill(src);
diff --git a/tests/validation/reference/Scale.cpp b/tests/validation/reference/Scale.cpp
index 5c9e956..f8a8b88 100644
--- a/tests/validation/reference/Scale.cpp
+++ b/tests/validation/reference/Scale.cpp
@@ -23,6 +23,7 @@
  */
 
 #include "Scale.h"
+
 #include "Utils.h"
 #include "arm_compute/core/utils/misc/Utility.h"
 #include "support/ToolchainSupport.h"