COMPMID-2060: Support different qinfo in PoolingLayer

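The CL and NEON back ends now accept input and output tensors with different quantization info (qinfo) in PoolingLayer; the pooled result is requantized to the output qinfo.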

Change-Id: I638d5f258ab2f99b40659601b4c5398d2c34c43b
Signed-off-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/927
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/CL/cl_kernels/pooling_layer_quantized.cl b/src/core/CL/cl_kernels/pooling_layer_quantized.cl
index 198250b..919b76e 100644
--- a/src/core/CL/cl_kernels/pooling_layer_quantized.cl
+++ b/src/core/CL/cl_kernels/pooling_layer_quantized.cl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -23,6 +23,25 @@
  */
 #include "helpers.h"
 
+#if defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT)
+#define VEC_FLOAT(VEC_SIZE) VEC_DATA_TYPE(float, VEC_SIZE)
+#define VEC_INT(VEC_SIZE) VEC_DATA_TYPE(int, VEC_SIZE)
+#define VEC_UCHAR(VEC_SIZE) VEC_DATA_TYPE(uchar, VEC_SIZE)
+#define CONVERT_RTE(x, type) (convert_##type##_rte((x)))
+/* CONVERT_RTE pastes `type` with ##, which blocks macro expansion of that argument; going through CONVERT_DOWN lets an argument such as VEC_INT(VEC_SIZE) expand first. */
+#define CONVERT_DOWN(x, type) CONVERT_RTE(x, type)
+/** Requantize a vector of VEC_SIZE uchar values from the input quantization to the output quantization.
+ *
+ * out_f32 = (input - in_offset) * in_scale / out_scale + out_offset; res = out_f32 converted to int with the _rte rounding mode, then to uchar through CONVERT_SAT.
+ */
+#define REQUANTIZE(VEC_SIZE, input, in_offset, out_offset, in_scale, out_scale, res)                                                                                      \
+    {                                                                                                                                                                     \
+        const VEC_FLOAT(VEC_SIZE) in_f32  = (CONVERT(input, VEC_FLOAT(VEC_SIZE)) - (VEC_FLOAT(VEC_SIZE))((float)(in_offset))) * (VEC_FLOAT(VEC_SIZE))((float)(in_scale)); \
+        const VEC_FLOAT(VEC_SIZE) out_f32 = in_f32 / (VEC_FLOAT(VEC_SIZE))((float)(out_scale)) + (VEC_FLOAT(VEC_SIZE))((float)(out_offset));                              \
+        res                               = CONVERT_SAT(CONVERT_DOWN(out_f32, VEC_INT(VEC_SIZE)), VEC_UCHAR(VEC_SIZE));                                                   \
+    }
+#endif /* defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT) */
+
 #if defined(POOL_AVG)
 #define POOL_OP(x, y) ((x) + (y))
 #else /* defined(POOL_AVG) */
@@ -118,8 +137,22 @@
     res = round(DIV_OP(res, calculate_avg_scale(POOL_SIZE_X, POOL_SIZE_Y, MAX_WIDTH, MAX_HEIGHT, PAD_X, PAD_Y, STRIDE_X, STRIDE_Y)));
 #endif /* defined(POOL_AVG) */
 
-    // Store result
-    *(__global uchar *)output.ptr = convert_uchar(res);
+    uchar result_u8 = convert_uchar(res);
+
+#if defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT)
+
+    const float result_f32   = convert_float(result_u8);
+    const float input_offset = (float)OFFSET_IN1;
+    const float input_scale  = (float)SCALE_IN1;
+    const float scale_out    = (float)SCALE_OUT;
+    const float offset_out   = (float)OFFSET_OUT;
+    const float in_f32       = (result_f32 - input_offset) * input_scale;
+    const float out_f32      = in_f32 / scale_out + offset_out;
+    result_u8                = convert_uchar_sat(convert_int_rte(out_f32));
+
+#endif /* defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT) */
+
+    *(__global uchar *)output.ptr = result_u8;
 }
 
 int calculate_avg_scale_nhwc(const int pool_size_x, const int pool_size_y, int upper_bound_w, int upper_bound_h,
@@ -217,6 +250,11 @@
     vdata = convert_int8(round(DIV_OP_NHWC(vdata, calculate_avg_scale_nhwc(POOL_SIZE_X, POOL_SIZE_Y, MAX_WIDTH, MAX_HEIGHT, PAD_X, PAD_Y, STRIDE_X, STRIDE_Y))));
 #endif /* defined(POOL_AVG) */
 
+    uchar8 out_u8 = convert_uchar8(vdata);
+#if defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT)
+    REQUANTIZE(8, out_u8, OFFSET_IN1, OFFSET_OUT, SCALE_IN1, SCALE_OUT, out_u8);
+#endif /* defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT) */
+
     // Store result
-    vstore8(convert_uchar8(vdata), 0, (__global uchar *)output.ptr);
-}
\ No newline at end of file
+    vstore8(out_u8, 0, (__global uchar *)output.ptr);
+}
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
index 7081688..7ccbda9 100644
--- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
@@ -78,7 +78,6 @@
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
         TensorInfo out_info(TensorInfo(compute_pool_shape(*input, pool_info), 1, output->data_type()));
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &out_info);
     }
@@ -201,6 +200,17 @@
     const int pool_pad_top  = pad_stride_info.pad_top();
     const int pool_pad_left = pad_stride_info.pad_left();
 
+    // Set build options
+    CLBuildOptions build_opts;
+
+    if(is_data_type_quantized_asymmetric(input->info()->data_type()) && input->info()->quantization_info() != output->info()->quantization_info())
+    {
+        build_opts.add_option("-DOFFSET_IN1=" + float_to_string_with_full_precision(input->info()->quantization_info().offset));
+        build_opts.add_option("-DOFFSET_OUT=" + float_to_string_with_full_precision(output->info()->quantization_info().offset));
+        build_opts.add_option("-DSCALE_IN1=" + float_to_string_with_full_precision(input->info()->quantization_info().scale));
+        build_opts.add_option("-DSCALE_OUT=" + float_to_string_with_full_precision(output->info()->quantization_info().scale));
+    }
+
     // Check output dimensions
     auto_init(input->info(), output->info(), pool_info);
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info));
@@ -212,8 +222,6 @@
 
     const DataType data_type = input->info()->data_type();
 
-    // Set build options
-    CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
     build_opts.add_option("-DPOOL_" + string_from_pooling_type(pool_type));
     build_opts.add_option("-DSTRIDE_X=" + support::cpp11::to_string(pool_stride_x));
@@ -222,6 +230,7 @@
     build_opts.add_option("-DPAD_Y=" + support::cpp11::to_string(pool_pad_top));
     build_opts.add_option("-DPOOL_SIZE_X=" + support::cpp11::to_string(pool_size_x));
     build_opts.add_option("-DPOOL_SIZE_Y=" + support::cpp11::to_string(pool_size_y));
+
     build_opts.add_option_if(data_type == DataType::F16, "-DFP16");
 
     // Create kernel
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index d00a4af..308fad5 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -138,7 +138,6 @@
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH)) != pooled_w)
                                     || (output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT)) != pooled_h));
     }
@@ -640,6 +639,15 @@
             }
         }
 
+        const QuantizationInfo &input_qinfo  = _input->info()->quantization_info();
+        const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
+        if(input_qinfo != output_qinfo)
+        {
+            const auto requantized_output = vquantize(vdequantize(vcombine_u8(lower_res, upper_res), input_qinfo), output_qinfo);
+            lower_res                     = vget_low_u8(requantized_output);
+            upper_res                     = vget_high_u8(requantized_output);
+        }
+
         // Store result
         if(pool_stride_x == 1)
         {
@@ -1641,6 +1649,11 @@
         }
 
         // Store result
+        const QuantizationInfo &input_qinfo  = _input->info()->quantization_info();
+        const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
+        res                                  = (input_qinfo != output_qinfo) ? sqcvt_qasymm8_f32(scvt_f32_qasymm8(res, input_qinfo.scale, input_qinfo.offset), output_qinfo.scale,
+                                                                                                 output_qinfo.offset) :
+                                               res;
         *(reinterpret_cast<uint8_t *>(output.ptr())) = res;
     },
     input, output);
@@ -1663,7 +1676,9 @@
     const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
     const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
 
-    const float32x4_t half_scale_v = vdupq_n_f32(0.5f);
+    const float32x4_t       half_scale_v = vdupq_n_f32(0.5f);
+    const QuantizationInfo &input_qinfo  = _input->info()->quantization_info();
+    const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
@@ -1713,6 +1728,12 @@
 
             uint8x8_t res1 = vmovn_u16(vcombine_u16(vmovn_u32(vres1), vmovn_u32(vres2)));
             uint8x8_t res2 = vmovn_u16(vcombine_u16(vmovn_u32(vres3), vmovn_u32(vres4)));
+            if(input_qinfo != output_qinfo)
+            {
+                const auto requantized_output = vquantize(vdequantize(vcombine_u8(res1, res2), input_qinfo), output_qinfo);
+                res1                          = vget_low_u8(requantized_output);
+                res2                          = vget_high_u8(requantized_output);
+            }
 
             // Store result
             vst1_u8(output.ptr(), res1);
@@ -1733,7 +1754,7 @@
             }
 
             // Store result
-            vst1q_u8(output.ptr(), vres);
+            vst1q_u8(output.ptr(), (input_qinfo != output_qinfo) ? vquantize(vdequantize(vres, input_qinfo), output_qinfo) : vres);
         }
     },
     input, output);
diff --git a/tests/validation/CL/PoolingLayer.cpp b/tests/validation/CL/PoolingLayer.cpp
index 6afafe6..7d79f3f 100644
--- a/tests/validation/CL/PoolingLayer.cpp
+++ b/tests/validation/CL/PoolingLayer.cpp
@@ -74,6 +74,8 @@
 constexpr AbsoluteTolerance<float>   tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */
 constexpr AbsoluteTolerance<float>   tolerance_f16(0.01f);  /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */
 constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);  /**< Tolerance value for comparing reference's output against implementation's output for 8-bit asymmetric type */
+const auto                           pool_data_layout_dataset = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
+
 } // namespace
 
 TEST_SUITE(CL)
@@ -133,7 +135,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
                                                                                                                   framework::dataset::make("DataType",
                                                                                                                           DataType::F32))),
-                                                                                                          framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                          pool_data_layout_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -141,7 +143,7 @@
 FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
                                                                                                                 framework::dataset::make("DataType",
                                                                                                                         DataType::F32))),
-                                                                                                        framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                        pool_data_layout_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -151,14 +153,14 @@
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
                                                                                                                  framework::dataset::make("DataType", DataType::F16))),
-                                                                                                         framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                         pool_data_layout_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
                                                                                                                framework::dataset::make("DataType", DataType::F16))),
-                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                       pool_data_layout_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -172,20 +174,16 @@
 using CLPoolingLayerQuantizedFixture = PoolingLayerValidationQuantizedFixture<CLTensor, CLAccessor, CLPoolingLayer, T>;
 
 TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
                                                                                                                      framework::dataset::make("DataType", DataType::QASYMM8))),
-                                                                                                                     framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255, 127),
-                                                                                                                             QuantizationInfo(7.f / 255, 123)
-                                                                                                                                                                  })),
-                                                                                                                     framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                                     pool_data_layout_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_qasymm8);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
+FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
                                                                                                                    framework::dataset::make("DataType", DataType::QASYMM8))),
-                                                                                                                   framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) })),
-                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                                   pool_data_layout_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/NEON/PoolingLayer.cpp b/tests/validation/NEON/PoolingLayer.cpp
index 9a15775..129f53b 100644
--- a/tests/validation/NEON/PoolingLayer.cpp
+++ b/tests/validation/NEON/PoolingLayer.cpp
@@ -67,6 +67,8 @@
 constexpr AbsoluteTolerance<float> tolerance_f16(0.01f);   /**< Tolerance value for comparing reference's output against implementation's output for float types */
 #endif                                                     /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for 8-bit asymmetric type */
+const auto                           pool_data_layout_dataset = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
+
 } // namespace
 
 TEST_SUITE(NEON)
@@ -124,7 +126,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
                                                                                                                   framework::dataset::make("DataType",
                                                                                                                           DataType::F32))),
-                                                                                                          framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                          pool_data_layout_dataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
@@ -132,7 +134,7 @@
 FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
                                                                                                                 framework::dataset::make("DataType",
                                                                                                                         DataType::F32))),
-                                                                                                        framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                        pool_data_layout_dataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
@@ -143,14 +145,14 @@
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
                                                                                                                  framework::dataset::make("DataType", DataType::F16))),
-                                                                                                         framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                         pool_data_layout_dataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f16);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
                                                                                                                framework::dataset::make("DataType", DataType::F16))),
-                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                       pool_data_layout_dataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f16);
@@ -165,20 +167,16 @@
 using NEPoolingLayerQuantizedFixture = PoolingLayerValidationQuantizedFixture<Tensor, Accessor, NEPoolingLayer, T>;
 
 TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
                                                                                                                      framework::dataset::make("DataType", DataType::QASYMM8))),
-                                                                                                                     framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255, 127),
-                                                                                                                             QuantizationInfo(7.f / 255, 123)
-                                                                                                                                                                  })),
-                                                                                                                     framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                                     pool_data_layout_dataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
                                                                                                                    framework::dataset::make("DataType", DataType::QASYMM8))),
-                                                                                                                   framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) })),
-                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+                                                                                                                   pool_data_layout_dataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 3e34f98..1813ef4 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -26,6 +26,7 @@
 
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "tests/AssetsLibrary.h"
 #include "tests/Globals.h"
@@ -47,13 +48,16 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
+    void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout)
     {
-        _quantization_info = quantization_info;
-        _pool_info         = pool_info;
+        std::mt19937                    gen(library->seed());
+        std::uniform_int_distribution<> offset_dis(0, 20);
+        const QuantizationInfo          input_qinfo(1.f / 255.f, offset_dis(gen));
+        const QuantizationInfo          output_qinfo(1.f / 255.f, offset_dis(gen));
 
-        _target    = compute_target(shape, pool_info, data_type, data_layout, quantization_info);
-        _reference = compute_reference(shape, pool_info, data_type, quantization_info);
+        _pool_info = pool_info;
+        _target    = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo);
+        _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo);
     }
 
 protected:
@@ -72,7 +76,7 @@
     }
 
     TensorType compute_target(TensorShape shape, PoolingLayerInfo info,
-                              DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
+                              DataType data_type, DataLayout data_layout, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
     {
         // Change shape in case of NHWC.
         if(data_layout == DataLayout::NHWC)
@@ -81,8 +85,9 @@
         }
 
         // Create tensors
-        TensorType src = create_tensor<TensorType>(shape, data_type, 1, quantization_info, data_layout);
-        TensorType dst;
+        TensorType        src       = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout);
+        const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info);
+        TensorType        dst       = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout);
 
         // Create and configure function
         FunctionType pool_layer;
@@ -107,21 +112,19 @@
         return dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info,
-                                      DataType data_type, QuantizationInfo quantization_info)
+    SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
     {
         // Create reference
-        SimpleTensor<T> src{ shape, data_type, 1, quantization_info };
+        SimpleTensor<T> src{ shape, data_type, 1, input_qinfo };
 
         // Fill reference
         fill(src);
 
-        return reference::pooling_layer<T>(src, info);
+        return reference::pooling_layer<T>(src, info, output_qinfo);
     }
 
     TensorType       _target{};
     SimpleTensor<T>  _reference{};
-    QuantizationInfo _quantization_info{};
     PoolingLayerInfo _pool_info{};
 };
 
@@ -133,7 +136,7 @@
     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
     {
         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding),
-                                                                                               data_type, data_layout, QuantizationInfo());
+                                                                                               data_type, data_layout);
     }
 };
 
@@ -142,11 +145,10 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type,
-               QuantizationInfo quantization_info, DataLayout data_layout = DataLayout::NCHW)
+    void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
     {
         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding),
-                                                                                               data_type, data_layout, quantization_info);
+                                                                                               data_type, data_layout);
     }
 };
 
@@ -157,7 +159,7 @@
     template <typename...>
     void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
     {
-        PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW, QuantizationInfo());
+        PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW);
     }
 };
 
@@ -168,7 +170,7 @@
     template <typename...>
     void setup(TensorShape shape, PoolingType pool_type, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
     {
-        PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type), data_type, DataLayout::NCHW, QuantizationInfo());
+        PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type), data_type, data_layout);
     }
 };
 
diff --git a/tests/validation/reference/PoolingLayer.cpp b/tests/validation/reference/PoolingLayer.cpp
index e617c93..f4112a4 100644
--- a/tests/validation/reference/PoolingLayer.cpp
+++ b/tests/validation/reference/PoolingLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -38,8 +38,9 @@
 using namespace arm_compute::misc::shape_calculator;
 
 template <typename T>
-SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info)
+SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo)
 {
+    ARM_COMPUTE_UNUSED(output_qinfo); // requantization occurs in pooling_layer<uint8_t>
     ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y()));
 
     // Create reference
@@ -152,16 +153,16 @@
 }
 
 template <>
-SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, const PoolingLayerInfo &info)
+SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo)
 {
     SimpleTensor<float>   src_tmp = convert_from_asymmetric(src);
-    SimpleTensor<float>   dst_tmp = pooling_layer<float>(src_tmp, info);
-    SimpleTensor<uint8_t> dst     = convert_to_asymmetric(dst_tmp, src.quantization_info());
+    SimpleTensor<float>   dst_tmp = pooling_layer<float>(src_tmp, info, output_qinfo);
+    SimpleTensor<uint8_t> dst     = convert_to_asymmetric(dst_tmp, output_qinfo);
     return dst;
 }
 
-template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, const PoolingLayerInfo &info);
-template SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, const PoolingLayerInfo &info);
+template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo);
+template SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/PoolingLayer.h b/tests/validation/reference/PoolingLayer.h
index 0097789..1c0b7ff 100644
--- a/tests/validation/reference/PoolingLayer.h
+++ b/tests/validation/reference/PoolingLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,7 +36,7 @@
 namespace reference
 {
 template <typename T>
-SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info);
+SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo);
 } // namespace reference
 } // namespace validation
 } // namespace test