IVGCVSW-3221 Refactor Mean ref workload and tests

 * Replaced RefMeanFloat32Workload and RefMeanUint8Workload
   with a single RefMeanWorkload and updated all references
   accordingly.
 * Refactored RefMeanWorkload to use Decoders/Encoders so
   that it supports multiple data types.
 * Deleted the reference Uint8 Mean tests as they duplicated
   the Float32 tests, refactored the remaining tests to
   support multiple data types and updated references.
 * Adjusted the values used in the tests' input tensors so
   that they exercise non-integral floating point values,
   e.g. changed 1.0f to 1.5f.
 * Replaced size_t with unsigned int in the Mean ref
   workload for better compatibility with the
   Encoder/Decoder interfaces, and removed some casts that
   this made unnecessary.
 * Added a ValidateTensorDataTypesMatch() function to
   WorkloadData.cpp and a CreateIncorrectDimensionsErrorMsg()
   function to RefLayerSupport.cpp.
 * Added passing and failing tests for ref IsMeanSupported.

Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: Id3d44463d1385255c727a497d4026d21a49e7eb2
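
Note for reviewers: the refactored Mean implementation reads and writes tensor
data through the Decoder/Encoder iterators instead of raw float pointers. Below
is a minimal sketch of that access idiom (operator[] positions the iterator,
Get() reads, Set() writes); the ScaleSketch() helper is purely illustrative and
is not part of this patch.

    #include "BaseIterator.hpp"   // Decoder<>, Encoder<>
    #include "armnn/Tensor.hpp"   // TensorInfo

    namespace armnn
    {

    // Illustrative helper (not in this patch): copies every element from
    // 'input' to 'output', scaling it by 'factor', using the same
    // Decoder/Encoder idiom as the refactored Mean() function.
    void ScaleSketch(const TensorInfo& inputInfo,
                     Decoder<float>& input,
                     Encoder<float>& output,
                     float factor)
    {
        const unsigned int numElements = inputInfo.GetNumElements();
        for (unsigned int idx = 0; idx < numElements; ++idx)
        {
            input[idx];                        // position the decoder
            output[idx];                       // position the encoder
            output.Set(input.Get() * factor);  // read, transform, write
        }
    }

    } // namespace armnn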
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index a1d00c6..1505078 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -271,6 +271,20 @@
     }
 }
 
+//---------------------------------------------------------------
+void ValidateTensorDataTypesMatch(const TensorInfo& first,
+                                  const TensorInfo& second,
+                                  std::string const& descName,
+                                  std::string const& firstName,
+                                  std::string const& secondName)
+{
+    if (first.GetDataType() != second.GetDataType())
+    {
+        throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
+                                       " must have identical data types.");
+    }
+}
+
 } //namespace
 
 void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
@@ -1275,25 +1289,40 @@
 
 void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateNumInputs(workloadInfo, "MeanQueueDescriptor", 1);
-    ValidateNumOutputs(workloadInfo, "MeanQueueDescriptor", 1);
+    const std::string meanQueueDescString = "MeanQueueDescriptor";
+
+    ValidateNumInputs(workloadInfo, meanQueueDescString, 1);
+    ValidateNumOutputs(workloadInfo, meanQueueDescString, 1);
+
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QuantisedAsymm8,
+        DataType::QuantisedSymm16
+    };
 
     const TensorInfo& input  = workloadInfo.m_InputTensorInfos[0];
     const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
 
+    // First check whether the input tensor data type is supported, then
+    // check that it matches the output tensor data type.
+    ValidateDataTypes(input,  supportedTypes, meanQueueDescString);
+    ValidateTensorDataTypesMatch(input, output, meanQueueDescString, "input", "output");
+
     if (m_Parameters.m_KeepDims)
     {
-        ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output");
+        ValidateTensorNumDimensions(output, meanQueueDescString, input.GetNumDimensions(), "output");
     }
     else if (m_Parameters.m_Axis.empty())
     {
-        ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output");
+        ValidateTensorNumDimensions(output, meanQueueDescString, 1, "output");
     }
     else
     {
         auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
         ValidateTensorNumDimensions(output,
-                                    "MeanQueueDescriptor",
+                                    meanQueueDescString,
                                     outputDim > 0 ? outputDim : 1,
                                     "output");
     }
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index fa6ec10..ff632fc 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -633,5 +633,34 @@
     return result;
 }
 
+// Tests that IsMeanSupported fails when the input tensor dimensions do not
+// match the output tensor dimensions while keepDims == true
+template<typename FactoryType, armnn::DataType InputDataType, armnn::DataType OutputDataType>
+bool IsMeanLayerNotSupportedTests(std::string& reasonIfUnsupported)
+{
+    armnn::Graph graph;
+    static const std::vector<unsigned> axes = {};
+    // Set keepDims == true
+    armnn::MeanDescriptor desc(axes, true);
+
+    armnn::Layer* const layer = graph.AddLayer<armnn::MeanLayer>(desc, "LayerName");
+
+    armnn::Layer* const input = graph.AddLayer<armnn::InputLayer>(0, "input");
+    armnn::Layer* const output = graph.AddLayer<armnn::OutputLayer>(0, "output");
+
+    // Mismatching number of tensor dimensions
+    armnn::TensorInfo inputTensorInfo({1, 1, 1, 1}, InputDataType);
+    armnn::TensorInfo outputTensorInfo({1, 1}, OutputDataType);
+
+    input->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+    input->GetOutputHandler(0).SetTensorInfo(inputTensorInfo);
+    layer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+    layer->GetOutputHandler(0).SetTensorInfo(outputTensorInfo);
+
+    bool result = FactoryType::IsLayerSupported(*layer, InputDataType, reasonIfUnsupported);
+
+    return result;
+}
+
 
 } //namespace
diff --git a/src/backends/backendsCommon/test/LayerTests.cpp b/src/backends/backendsCommon/test/LayerTests.cpp
index 9d40197..55e799e 100644
--- a/src/backends/backendsCommon/test/LayerTests.cpp
+++ b/src/backends/backendsCommon/test/LayerTests.cpp
@@ -8368,237 +8368,6 @@
     return PermuteFloat32ValueSet3TestCommon(workloadFactory, memoryManager);
 };
 
-namespace
-{
-
-template <typename T, std::size_t InputDim, std::size_t OutputDim>
-LayerTestResult<T, OutputDim> MeanTestHelper(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
-    const unsigned int* inputShape,
-    const std::vector<T>& inputData,
-    const std::vector<unsigned int>& axis,
-    bool keepDims,
-    const unsigned int* outputShape,
-    const std::vector<T>& outputData,
-    float scale = 1.0f,
-    int32_t offset = 0)
-{
-    auto dataType = (std::is_same<T, uint8_t>::value ? armnn::DataType::QuantisedAsymm8 : armnn::DataType::Float32);
-
-    armnn::TensorInfo inputTensorInfo(InputDim, inputShape, dataType);
-    armnn::TensorInfo outputTensorInfo(OutputDim, outputShape, dataType);
-
-    inputTensorInfo.SetQuantizationScale(scale);
-    inputTensorInfo.SetQuantizationOffset(offset);
-
-    outputTensorInfo.SetQuantizationScale(scale);
-    outputTensorInfo.SetQuantizationOffset(offset);
-
-    auto input = MakeTensor<T, InputDim>(inputTensorInfo, inputData);
-
-    LayerTestResult<T, OutputDim> result(outputTensorInfo);
-    result.outputExpected = MakeTensor<T, OutputDim>(outputTensorInfo, outputData);
-
-    std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
-    std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
-
-    armnn::MeanQueueDescriptor data;
-    data.m_Parameters.m_Axis = axis;
-    data.m_Parameters.m_KeepDims = keepDims;
-    armnn::WorkloadInfo info;
-    AddInputToWorkload(data,  info, inputTensorInfo, inputHandle.get());
-    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
-
-    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateMean(data, info);
-
-    inputHandle->Allocate();
-    outputHandle->Allocate();
-
-    CopyDataToITensorHandle(inputHandle.get(), input.origin());
-
-    workload->PostAllocationConfigure();
-    workload->Execute();
-
-    CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
-
-    return result;
-}
-
-} // anonymous namespace
-
-LayerTestResult<uint8_t, 1> MeanUint8SimpleTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 3, 2 };
-    const unsigned int outputShape[] = { 1 };
-
-    std::vector<uint8_t> input({ 1, 1, 2, 2, 3, 3 });
-    std::vector<uint8_t> output({ 2 });
-
-    return MeanTestHelper<uint8_t, 2, 1>(
-        workloadFactory, memoryManager, inputShape, input, {}, false, outputShape, output);
-}
-
-LayerTestResult<uint8_t, 3> MeanUint8SimpleAxisTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 1, 1, 3, 2 };
-    const unsigned int outputShape[] = { 1, 1, 2 };
-
-    std::vector<uint8_t> input({ 1, 1, 2, 2, 3, 3 });
-    std::vector<uint8_t> output({ 2, 2 });
-
-    return MeanTestHelper<uint8_t, 4, 3>(
-        workloadFactory, memoryManager, inputShape, input, { 2 }, false, outputShape, output);
-}
-
-LayerTestResult<uint8_t, 4> MeanUint8KeepDimsTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 1, 1, 3, 2 };
-    const unsigned int outputShape[] = { 1, 1, 1, 2 };
-
-    std::vector<uint8_t> input({ 1, 1, 2, 2, 3, 3 });
-    std::vector<uint8_t> output({ 2, 2 });
-
-    return MeanTestHelper<uint8_t, 4, 4>(
-        workloadFactory, memoryManager, inputShape, input, { 2 }, true, outputShape, output);
-}
-
-LayerTestResult<uint8_t, 4> MeanUint8MultipleDimsTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 2, 3, 1, 2 };
-    const unsigned int outputShape[] = { 1, 3, 1, 1 };
-
-    std::vector<uint8_t> input({ 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 });
-    std::vector<uint8_t> output({ 1, 3, 5 });
-
-    return MeanTestHelper<uint8_t, 4, 4>(
-        workloadFactory, memoryManager, inputShape, input, { 0, 3 }, true, outputShape, output);
-}
-
-LayerTestResult<uint8_t, 1> MeanVtsUint8Test(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 4, 3, 2 };
-    const unsigned int outputShape[] = { 2 };
-
-    std::vector<uint8_t> input({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
-                                 24 });
-    std::vector<uint8_t> output({ 12, 13 });
-
-    return MeanTestHelper<uint8_t, 3, 1>(workloadFactory, memoryManager,
-                                         inputShape, input, { 0, 1 }, false, outputShape,
-                                         output, 0.8f, 5);
-}
-
-LayerTestResult<float, 1> MeanFloatSimpleTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 3, 2 };
-    const unsigned int outputShape[] = { 1 };
-
-    std::vector<float> input({ 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f });
-    std::vector<float> output({ 2.0f });
-
-    return MeanTestHelper<float, 2, 1>(
-        workloadFactory, memoryManager, inputShape, input, {}, false, outputShape, output);
-}
-
-LayerTestResult<float, 3> MeanFloatSimpleAxisTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 2, 3, 1, 2 };
-    const unsigned int outputShape[] = { 3, 1, 2 };
-
-    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f });
-    std::vector<float> output({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f });
-
-    return MeanTestHelper<float, 4, 3>(
-        workloadFactory, memoryManager, inputShape, input, { 0 }, false, outputShape, output);
-}
-
-LayerTestResult<float, 4> MeanFloatKeepDimsTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 1, 1, 3, 2 };
-    const unsigned int outputShape[] = { 1, 1, 1, 2 };
-
-    std::vector<float> input({ 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f });
-    std::vector<float> output({ 2.0f, 2.0f });
-
-    return MeanTestHelper<float, 4, 4>(
-        workloadFactory, memoryManager, inputShape, input, { 2 }, true, outputShape, output);
-}
-
-LayerTestResult<float, 4> MeanFloatMultipleDimsTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 2, 3, 1, 2 };
-    const unsigned int outputShape[] = { 1, 3, 1, 1 };
-
-    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f });
-    std::vector<float> output({ 1.5f, 3.5f, 5.5f });
-
-    return MeanTestHelper<float, 4, 4>(
-        workloadFactory, memoryManager, inputShape, input, { 0, 3 }, true, outputShape, output);
-}
-
-LayerTestResult<float, 1> MeanVtsFloat1Test(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 4, 3, 2 };
-    const unsigned int outputShape[] = { 2 };
-
-    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f,
-                               15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f });
-    std::vector<float> output({ 12.0f, 13.0f });
-
-    return MeanTestHelper<float, 3, 1>(
-        workloadFactory, memoryManager, inputShape, input, { 0, 1 }, false, outputShape, output);
-}
-
-LayerTestResult<float, 3> MeanVtsFloat2Test(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 4, 3, 2 };
-    const unsigned int outputShape[] = { 1, 3, 1 };
-
-    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f,
-                               15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f });
-    std::vector<float> output({ 10.5f, 12.5f, 14.5f });
-
-    return MeanTestHelper<float, 3, 3>(
-        workloadFactory, memoryManager, inputShape, input, { 0, 2 }, true, outputShape, output);
-}
-
-LayerTestResult<float, 3> MeanVtsFloat3Test(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
-{
-    const unsigned int inputShape[] = { 1, 2, 2, 1 };
-    const unsigned int outputShape[] = { 1, 2, 1 };
-
-    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f });
-    std::vector<float> output({ 1.5f, 3.5f });
-
-    return MeanTestHelper<float, 4, 3>(
-        workloadFactory, memoryManager, inputShape, input, { 2 }, false, outputShape, output);
-}
-
 LayerTestResult<float, 4> AdditionAfterMaxPoolTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index 3eed189..fab1ad8 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -1448,51 +1448,38 @@
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
-LayerTestResult<uint8_t, 1> MeanUint8SimpleTest(
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 1> MeanSimpleTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
-LayerTestResult<uint8_t, 3> MeanUint8SimpleAxisTest(
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> MeanSimpleAxisTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
-LayerTestResult<uint8_t, 4> MeanUint8KeepDimsTest(
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> MeanKeepDimsTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
-LayerTestResult<uint8_t, 4> MeanUint8MultipleDimsTest(
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> MeanMultipleDimsTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
-LayerTestResult<uint8_t, 1> MeanVtsUint8Test(
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 1> MeanVts1Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
-LayerTestResult<float, 1> MeanFloatSimpleTest(
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> MeanVts2Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
-LayerTestResult<float, 3> MeanFloatSimpleAxisTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
-
-LayerTestResult<float, 4> MeanFloatKeepDimsTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
-
-LayerTestResult<float, 4> MeanFloatMultipleDimsTest(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
-
-LayerTestResult<float, 1> MeanVtsFloat1Test(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
-
-LayerTestResult<float, 3> MeanVtsFloat2Test(
-    armnn::IWorkloadFactory& workloadFactory,
-    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
-
-LayerTestResult<float, 3> MeanVtsFloat3Test(
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> MeanVts3Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
 
@@ -2912,4 +2899,164 @@
 ResizeBilinearMagTest<armnn::DataType::QuantisedAsymm8>(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
-        const armnn::DataLayout dataLayout);
\ No newline at end of file
+        const armnn::DataLayout dataLayout);
+
+template<armnn::DataType ArmnnType, typename T, std::size_t InputDim, std::size_t OutputDim>
+LayerTestResult<T, OutputDim> MeanTestHelper(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        const unsigned int* inputShape,
+        const std::vector<float>& inputData,
+        const std::vector<unsigned int>& axis,
+        bool keepDims,
+        const unsigned int* outputShape,
+        const std::vector<float>& outputData,
+        float scale = 1.0f,
+        int32_t offset = 0)
+{
+    armnn::TensorInfo inputTensorInfo(InputDim, inputShape, ArmnnType);
+    armnn::TensorInfo outputTensorInfo(OutputDim, outputShape, ArmnnType);
+
+    inputTensorInfo.SetQuantizationScale(scale);
+    inputTensorInfo.SetQuantizationOffset(offset);
+
+    outputTensorInfo.SetQuantizationScale(scale);
+    outputTensorInfo.SetQuantizationOffset(offset);
+
+    auto input = MakeTensor<T, InputDim>(inputTensorInfo, ConvertToDataType<ArmnnType>(inputData, inputTensorInfo));
+
+    LayerTestResult<T, OutputDim> result(outputTensorInfo);
+    result.outputExpected = MakeTensor<T, OutputDim>(
+            outputTensorInfo, ConvertToDataType<ArmnnType>(outputData, outputTensorInfo));
+
+    std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
+    std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+
+    armnn::MeanQueueDescriptor data;
+    data.m_Parameters.m_Axis = axis;
+    data.m_Parameters.m_KeepDims = keepDims;
+    armnn::WorkloadInfo info;
+    AddInputToWorkload(data,  info, inputTensorInfo, inputHandle.get());
+    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
+
+    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateMean(data, info);
+
+    inputHandle->Allocate();
+    outputHandle->Allocate();
+
+    CopyDataToITensorHandle(inputHandle.get(), input.origin());
+
+    workload->PostAllocationConfigure();
+    workload->Execute();
+
+    CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
+
+    return result;
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 1> MeanSimpleTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+    const unsigned int inputShape[] = { 3, 2 };
+    const unsigned int outputShape[] = { 1 };
+
+    std::vector<float> input({ 1.5f, 1.5f, 2.5f, 2.5f, 3.5f, 3.5f });
+    std::vector<float> output({ 2.5f });
+
+    return MeanTestHelper<ArmnnType, T, 2, 1>(
+            workloadFactory, memoryManager, inputShape, input, {}, false, outputShape, output);
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> MeanSimpleAxisTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+    const unsigned int inputShape[] = { 2, 3, 1, 2 };
+    const unsigned int outputShape[] = { 3, 1, 2 };
+
+    std::vector<float> input({ 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f });
+    std::vector<float> output({ 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f });
+
+    return MeanTestHelper<ArmnnType, T, 4, 3>(
+            workloadFactory, memoryManager, inputShape, input, { 0 }, false, outputShape, output);
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> MeanKeepDimsTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+    const unsigned int inputShape[] = { 1, 1, 3, 2 };
+    const unsigned int outputShape[] = { 1, 1, 1, 2 };
+
+    std::vector<float> input({ 1.5f, 1.5f, 2.5f, 2.5f, 3.5f, 3.5f });
+    std::vector<float> output({ 2.5f, 2.5f });
+
+    return MeanTestHelper<ArmnnType, T, 4, 4>(
+            workloadFactory, memoryManager, inputShape, input, { 2 }, true, outputShape, output);
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> MeanMultipleDimsTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+    const unsigned int inputShape[] = { 2, 3, 1, 2 };
+    const unsigned int outputShape[] = { 1, 3, 1, 1 };
+
+    std::vector<float> input({ 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f });
+    std::vector<float> output({ 2.0f, 4.0f, 6.0f });
+
+    return MeanTestHelper<ArmnnType, T, 4, 4>(
+            workloadFactory, memoryManager, inputShape, input, { 0, 3 }, true, outputShape, output);
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 1> MeanVts1Test(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+    const unsigned int inputShape[] = { 4, 3, 2 };
+    const unsigned int outputShape[] = { 2 };
+
+    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f,
+                               15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f });
+    std::vector<float> output({ 12.0f, 13.0f });
+
+    return MeanTestHelper<ArmnnType, T, 3, 1>(
+            workloadFactory, memoryManager, inputShape, input, { 0, 1 }, false, outputShape, output);
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> MeanVts2Test(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+    const unsigned int inputShape[] = { 4, 3, 2 };
+    const unsigned int outputShape[] = { 1, 3, 1 };
+
+    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f,
+                               15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f });
+    std::vector<float> output({ 10.5f, 12.5f, 14.5f });
+
+    return MeanTestHelper<ArmnnType, T, 3, 3>(
+            workloadFactory, memoryManager, inputShape, input, { 0, 2 }, true, outputShape, output);
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> MeanVts3Test(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+    const unsigned int inputShape[] = { 1, 2, 2, 1 };
+    const unsigned int outputShape[] = { 1, 2, 1 };
+
+    std::vector<float> input({ 1.0f, 2.0f, 3.0f, 4.0f });
+    std::vector<float> output({ 1.5f, 3.5f });
+
+    return MeanTestHelper<ArmnnType, T, 4, 3>(
+            workloadFactory, memoryManager, inputShape, input, { 2 }, false, outputShape, output);
+}
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index fee980c..164a42a 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -357,19 +357,21 @@
 ARMNN_AUTO_TEST_CASE(MaximumBroadcast1DVectorUint8, MaximumBroadcast1DVectorUint8Test)
 
 // Mean
-ARMNN_AUTO_TEST_CASE(MeanUint8Simple, MeanUint8SimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8SimpleAxis, MeanUint8SimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8KeepDims, MeanUint8KeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8MultipleDims, MeanUint8MultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsUint8, MeanVtsUint8Test)
+ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisFloat32, MeanSimpleAxisTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsFloat32, MeanKeepDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsFloat32, MeanMultipleDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts1Float32, MeanVts1Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts2Float32, MeanVts2Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts3Float32, MeanVts3Test<armnn::DataType::Float32>)
 
-ARMNN_AUTO_TEST_CASE(MeanFloatSimple, MeanFloatSimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatSimpleAxis, MeanFloatSimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatKeepDims, MeanFloatKeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatMultipleDims, MeanFloatMultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat1, MeanVtsFloat1Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat2, MeanVtsFloat2Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat3, MeanVtsFloat3Test)
+ARMNN_AUTO_TEST_CASE(MeanSimpleQuantisedAsymm8, MeanSimpleTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisQuantisedAsymm8, MeanSimpleAxisTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsQuantisedAsymm8, MeanKeepDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsQuantisedAsymm8, MeanMultipleDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts1QuantisedAsymm8, MeanVts1Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts2QuantisedAsymm8, MeanVts2Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts3QuantisedAsymm8, MeanVts3Test<armnn::DataType::QuantisedAsymm8>)
 
 // Minimum
 ARMNN_AUTO_TEST_CASE(MinimumBroadcast1Element1, MinimumBroadcast1ElementTest1)
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 4e719d2..af9db52 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -438,19 +438,21 @@
                      LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest)
 
 // Mean
-ARMNN_AUTO_TEST_CASE(MeanUint8Simple, MeanUint8SimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8SimpleAxis, MeanUint8SimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8KeepDims, MeanUint8KeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8MultipleDims, MeanUint8MultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsUint8, MeanVtsUint8Test)
+ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisFloat32, MeanSimpleAxisTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsFloat32, MeanKeepDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsFloat32, MeanMultipleDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts1Float32, MeanVts1Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts2Float32, MeanVts2Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts3Float32, MeanVts3Test<armnn::DataType::Float32>)
 
-ARMNN_AUTO_TEST_CASE(MeanFloatSimple, MeanFloatSimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatSimpleAxis, MeanFloatSimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatKeepDims, MeanFloatKeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatMultipleDims, MeanFloatMultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat1, MeanVtsFloat1Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat2, MeanVtsFloat2Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat3, MeanVtsFloat3Test)
+ARMNN_AUTO_TEST_CASE(MeanSimpleQuantisedAsymm8, MeanSimpleTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisQuantisedAsymm8, MeanSimpleAxisTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsQuantisedAsymm8, MeanKeepDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsQuantisedAsymm8, MeanMultipleDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts1QuantisedAsymm8, MeanVts1Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts2QuantisedAsymm8, MeanVts2Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts3QuantisedAsymm8, MeanVts3Test<armnn::DataType::QuantisedAsymm8>)
 
 // Max
 ARMNN_AUTO_TEST_CASE(SimpleMaximum, MaximumSimpleTest)
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index cf1814e..402bd66 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -47,6 +47,21 @@
 
 } // anonymous namespace
 
+namespace
+{
+
+std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
+                                              unsigned int actual,
+                                              std::string& layerStr,
+                                              std::string& tensorName)
+{
+    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
+                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
+
+    return errorMsg;
+}
+
+} // anonymous namespace
 
 namespace
 {
@@ -177,6 +192,15 @@
         }
     }
 };
+
+struct TensorNumDimensionsAreCorrect : public Rule
+{
+    TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
+    {
+        m_Res = info.GetNumDimensions() == expectedNumDimensions;
+    }
+};
+
 } // namespace
 
 
@@ -874,12 +898,58 @@
                                       const MeanDescriptor& descriptor,
                                       Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(output);
-    ignore_unused(descriptor);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+    bool supported = true;
+    std::string meanLayerStr = "Mean";
+    std::string outputTensorStr = "output";
+
+    std::array<DataType,2> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::QuantisedAsymm8
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference Mean: input type not supported.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference Mean: input and output types are mismatched");
+
+    if (descriptor.m_KeepDims)
+    {
+        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
+                                      reasonIfUnsupported,
+                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
+                                                                        output.GetNumDimensions(),
+                                                                        meanLayerStr, outputTensorStr).data());
+    }
+    else if (descriptor.m_Axis.empty())
+    {
+        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
+                                      reasonIfUnsupported,
+                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
+                                                                        meanLayerStr, outputTensorStr).data());
+    }
+    else
+    {
+        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
+
+        if (outputDim > 0)
+        {
+            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
+                                          reasonIfUnsupported,
+                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
+                                                                            meanLayerStr, outputTensorStr).data());
+        }
+        else
+        {
+            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
+                                          reasonIfUnsupported,
+                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
+                                                                            meanLayerStr, outputTensorStr).data());
+        }
+    }
+
+    return supported;
 }
 
 bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 728e605..4467bd4 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -353,7 +353,11 @@
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
     const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefMeanFloat32Workload, RefMeanUint8Workload>(descriptor, info);
+    if (IsFloat16(info))
+    {
+        return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+    }
+    return std::make_unique<RefMeanWorkload>(descriptor, info);
 }
 
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMinimum(
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index c4a0c76..ecd2812 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -45,8 +45,7 @@
         workloads/RefGatherWorkload.cpp \
         workloads/RefL2NormalizationWorkload.cpp \
         workloads/RefLstmWorkload.cpp \
-        workloads/RefMeanFloat32Workload.cpp \
-        workloads/RefMeanUint8Workload.cpp \
+        workloads/RefMeanWorkload.cpp \
         workloads/RefNormalizationWorkload.cpp \
         workloads/RefPadWorkload.cpp \
         workloads/RefPermuteWorkload.cpp \
diff --git a/src/backends/reference/test/RefLayerSupportTests.cpp b/src/backends/reference/test/RefLayerSupportTests.cpp
index 2c7e17da..0d99b3e 100644
--- a/src/backends/reference/test/RefLayerSupportTests.cpp
+++ b/src/backends/reference/test/RefLayerSupportTests.cpp
@@ -14,6 +14,7 @@
 #include <backendsCommon/test/IsLayerSupportedTestImpl.hpp>
 
 #include <boost/test/unit_test.hpp>
+#include <boost/algorithm/string/trim.hpp>
 
 #include <string>
 
@@ -130,4 +131,28 @@
     BOOST_CHECK_EQUAL(reasonIfUnsupported, "Layer is not supported with float32 data type output");
 }
 
+BOOST_AUTO_TEST_CASE(IsLayerSupportedMeanDimensionsReference)
+{
+    std::string reasonIfUnsupported;
+
+    bool result = IsMeanLayerSupportedTests<armnn::RefWorkloadFactory,
+            armnn::DataType::Float32, armnn::DataType::Float32>(reasonIfUnsupported);
+
+    BOOST_CHECK(result);
+}
+
+BOOST_AUTO_TEST_CASE(IsLayerNotSupportedMeanDimensionsReference)
+{
+    std::string reasonIfUnsupported;
+
+    bool result = IsMeanLayerNotSupportedTests<armnn::RefWorkloadFactory,
+            armnn::DataType::Float32, armnn::DataType::Float32>(reasonIfUnsupported);
+
+    BOOST_CHECK(!result);
+
+    boost::algorithm::trim(reasonIfUnsupported);
+    BOOST_CHECK_EQUAL(reasonIfUnsupported,
+                      "Reference Mean: Expected 4 dimensions but got 2 dimensions instead, for the 'output' tensor.");
+}
+
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 7ff6d1b..c2cda8e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -607,19 +607,21 @@
 ARMNN_AUTO_TEST_CASE(SimpleConvertFp32ToFp16, SimpleConvertFp32ToFp16Test)
 
 // Mean
-ARMNN_AUTO_TEST_CASE(MeanUint8Simple, MeanUint8SimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8SimpleAxis, MeanUint8SimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8KeepDims, MeanUint8KeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanUint8MultipleDims, MeanUint8MultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsUint8, MeanVtsUint8Test)
+ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisFloat32, MeanSimpleAxisTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsFloat32, MeanKeepDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsFloat32, MeanMultipleDimsTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts1Float32, MeanVts1Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts2Float32, MeanVts2Test<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(MeanVts3Float32, MeanVts3Test<armnn::DataType::Float32>)
 
-ARMNN_AUTO_TEST_CASE(MeanFloatSimple, MeanFloatSimpleTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatSimpleAxis, MeanFloatSimpleAxisTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatKeepDims, MeanFloatKeepDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanFloatMultipleDims, MeanFloatMultipleDimsTest)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat1, MeanVtsFloat1Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat2, MeanVtsFloat2Test)
-ARMNN_AUTO_TEST_CASE(MeanVtsFloat3, MeanVtsFloat3Test)
+ARMNN_AUTO_TEST_CASE(MeanSimpleQuantisedAsymm8, MeanSimpleTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanSimpleAxisQuantisedAsymm8, MeanSimpleAxisTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanKeepDimsQuantisedAsymm8, MeanKeepDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanMultipleDimsQuantisedAsymm8, MeanMultipleDimsTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts1QuantisedAsymm8, MeanVts1Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts2QuantisedAsymm8, MeanVts2Test<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(MeanVts3QuantisedAsymm8, MeanVts3Test<armnn::DataType::QuantisedAsymm8>)
 
 ARMNN_AUTO_TEST_CASE(AdditionAfterMaxPool, AdditionAfterMaxPoolTest)
 
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index ebd3390..1ab38cc 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -119,10 +119,8 @@
     TensorBufferArrayView.hpp
     Mean.cpp
     Mean.hpp
-    RefMeanFloat32Workload.cpp
-    RefMeanFloat32Workload.hpp
-    RefMeanUint8Workload.cpp
-    RefMeanUint8Workload.hpp
+    RefMeanWorkload.cpp
+    RefMeanWorkload.hpp
 )
 
 add_library(armnnRefBackendWorkloads OBJECT ${armnnRefBackendWorkloads_sources})
diff --git a/src/backends/reference/workloads/Mean.cpp b/src/backends/reference/workloads/Mean.cpp
index 530aade..3ac3af9 100644
--- a/src/backends/reference/workloads/Mean.cpp
+++ b/src/backends/reference/workloads/Mean.cpp
@@ -36,10 +36,13 @@
     return (carry == 0);
 }
 
-std::size_t ReducedOutputOffset(const unsigned int numDims, const armnn::TensorShape& dims,
-                                std::vector<unsigned int>& index, const unsigned int numAxis,
-                                const std::vector<unsigned int>& axis) {
-    std::size_t offset = 0;
+unsigned int ReducedOutputOffset(const unsigned int numDims,
+                                 const armnn::TensorShape& dims,
+                                 std::vector<unsigned int>& index,
+                                 const unsigned int numAxis,
+                                 const std::vector<unsigned int>& axis)
+{
+    unsigned int offset = 0;
     for (unsigned int idx = 0; idx < numDims; ++idx)
     {
         bool isAxis = false;
@@ -56,7 +59,7 @@
         }
         if (!isAxis)
         {
-            offset = offset * boost::numeric_cast<size_t>(dims[idx]) + boost::numeric_cast<size_t>(index[idx]);
+            offset = offset * dims[idx] + index[idx];
         }
     }
     return offset;
@@ -68,8 +71,9 @@
 void Mean(const armnn::TensorInfo& inputInfo,
           const armnn::TensorInfo& outputInfo,
           const std::vector<unsigned int>& axis,
-          const float* inputData,
-          float* outputData) {
+          Decoder<float>& input,
+          Encoder<float>& output)
+{
 
     unsigned int inputNumDims = inputInfo.GetNumDimensions();
     unsigned int outputNumDims = outputInfo.GetNumDimensions();
@@ -78,16 +82,17 @@
     armnn::TensorShape inputDims = inputInfo.GetShape();
 
     // Initialise output data.
-    size_t numOutputs = 1;
+    unsigned int numOutputs = 1;
     for (unsigned int idx = 0; idx < outputNumDims; ++idx)
     {
-        numOutputs *= boost::numeric_cast<size_t>(outputDims[idx]);
+        numOutputs *= outputDims[idx];
     }
 
     std::vector<float> tempSum(numOutputs);
-    for (size_t idx = 0; idx < numOutputs; ++idx)
+    for (unsigned int idx = 0; idx < numOutputs; ++idx)
     {
-        outputData[idx] = 0.0f;
+        output[idx];
+        output.Set(0.0f);
         tempSum[idx] = 0.0f;
     }
 
@@ -106,30 +111,32 @@
           resolvedAxis.push_back(idx);
       }
     }
-    unsigned int numResolvedAxis = boost::numeric_cast<unsigned int>(resolvedAxis.size());
+    auto numResolvedAxis = boost::numeric_cast<unsigned int>(resolvedAxis.size());
 
     // Iterates through input_data and sum up the reduced axis.
     for (bool hasNext = true; hasNext; hasNext = NextIndex(inputNumDims, inputDims, tempIndex))
     {
-        size_t inputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex, 0, {});
-        size_t outputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex,
-                                                  numResolvedAxis, resolvedAxis);
-        tempSum[outputOffset] += inputData[inputOffset];
+        unsigned int inputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex, 0, {});
+        unsigned int outputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex,
+                                                        numResolvedAxis, resolvedAxis);
+        input[inputOffset];
+        tempSum[outputOffset] += input.Get();
     }
 
     // Takes average by num of elements added to get mean.
     size_t numElementsInAxis = 1;
     for (unsigned int idx = 0; idx < numResolvedAxis; ++idx)
     {
-        size_t current = boost::numeric_cast<size_t>(inputDims[resolvedAxis[idx]]);
+        unsigned int current = inputDims[resolvedAxis[idx]];
         BOOST_ASSERT(boost::numeric_cast<float>(current) <
               (std::numeric_limits<float>::max() / boost::numeric_cast<float>(numElementsInAxis)));
         numElementsInAxis *= current;
     }
     if (numElementsInAxis > 0) {
-        for (size_t idx = 0; idx < numOutputs; ++idx)
+        for (unsigned int idx = 0; idx < numOutputs; ++idx)
         {
-            outputData[idx] = tempSum[idx] / boost::numeric_cast<float>(numElementsInAxis);
+            output[idx];
+            output.Set(tempSum[idx] / boost::numeric_cast<float>(numElementsInAxis));
         }
     }
 }
diff --git a/src/backends/reference/workloads/Mean.hpp b/src/backends/reference/workloads/Mean.hpp
index 38c2e39..dfb0302 100644
--- a/src/backends/reference/workloads/Mean.hpp
+++ b/src/backends/reference/workloads/Mean.hpp
@@ -7,6 +7,7 @@
 
 #include "armnn/DescriptorsFwd.hpp"
 #include "armnn/Tensor.hpp"
+#include "BaseIterator.hpp"
 
 #include <vector>
 
@@ -15,7 +16,7 @@
 void Mean(const TensorInfo& inputInfo,
           const TensorInfo& outputInfo,
           const std::vector<unsigned int>& axis,
-          const float* inputData,
-          float* outputData);
+          Decoder<float>& input,
+          Encoder<float>& output);
 } //namespace armnn
 
diff --git a/src/backends/reference/workloads/RefMeanFloat32Workload.cpp b/src/backends/reference/workloads/RefMeanFloat32Workload.cpp
deleted file mode 100644
index a23906b..0000000
--- a/src/backends/reference/workloads/RefMeanFloat32Workload.cpp
+++ /dev/null
@@ -1,35 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefMeanFloat32Workload.hpp"
-
-#include "Mean.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-#include "vector"
-
-namespace armnn
-{
-
-RefMeanFloat32Workload::RefMeanFloat32Workload(const MeanQueueDescriptor& descriptor, const WorkloadInfo& info)
-  :Float32Workload<MeanQueueDescriptor>(descriptor, info) {}
-
-
-void RefMeanFloat32Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefMeanFloat32Workload_Execute");
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-    const float* inputData = GetInputTensorDataFloat(0, m_Data);
-    float* outputData = GetOutputTensorDataFloat(0, m_Data);
-
-    Mean(inputInfo, outputInfo, m_Data.m_Parameters.m_Axis, inputData, outputData);
-}
-
-} //namespace armnn
-
-
diff --git a/src/backends/reference/workloads/RefMeanFloat32Workload.hpp b/src/backends/reference/workloads/RefMeanFloat32Workload.hpp
deleted file mode 100644
index 153ebe1..0000000
--- a/src/backends/reference/workloads/RefMeanFloat32Workload.hpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "backendsCommon/Workload.hpp"
-#include "backendsCommon/WorkloadData.hpp"
-
-namespace armnn
-{
-
-
-class RefMeanFloat32Workload : public Float32Workload<MeanQueueDescriptor>
-{
-public:
-    explicit RefMeanFloat32Workload (const MeanQueueDescriptor& descriptor, const WorkloadInfo& info);
-    virtual void Execute() const override;
-};
-
-}//namespace armnn
diff --git a/src/backends/reference/workloads/RefMeanUint8Workload.cpp b/src/backends/reference/workloads/RefMeanUint8Workload.cpp
deleted file mode 100644
index 4ebffcf..0000000
--- a/src/backends/reference/workloads/RefMeanUint8Workload.cpp
+++ /dev/null
@@ -1,39 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefMeanUint8Workload.hpp"
-
-#include "Mean.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-
-#include <vector>
-
-namespace armnn
-{
-
-RefMeanUint8Workload::RefMeanUint8Workload(const MeanQueueDescriptor& descriptor, const WorkloadInfo& info)
-  :Uint8Workload<MeanQueueDescriptor>(descriptor, info) {}
-
-
-void RefMeanUint8Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefMeanUint8Workload_Execute");
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
-    auto dequant = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo);
-
-    std::vector<float> results(outputInfo.GetNumElements());
-
-    Mean(inputInfo, outputInfo, m_Data.m_Parameters.m_Axis, dequant.data(), results.data());
-
-    Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
-}
-
-} //namespace armnn
-
diff --git a/src/backends/reference/workloads/RefMeanUint8Workload.hpp b/src/backends/reference/workloads/RefMeanUint8Workload.hpp
deleted file mode 100644
index f53b8a4..0000000
--- a/src/backends/reference/workloads/RefMeanUint8Workload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "backendsCommon/Workload.hpp"
-#include "backendsCommon/WorkloadData.hpp"
-
-namespace armnn
-{
-
-class RefMeanUint8Workload : public Uint8Workload<MeanQueueDescriptor>
-{
-public:
-    explicit RefMeanUint8Workload (const MeanQueueDescriptor& descriptor, const WorkloadInfo& info);
-    virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefMeanWorkload.cpp b/src/backends/reference/workloads/RefMeanWorkload.cpp
new file mode 100644
index 0000000..375ab39
--- /dev/null
+++ b/src/backends/reference/workloads/RefMeanWorkload.cpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefMeanWorkload.hpp"
+
+#include "Mean.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+#include <vector>
+
+namespace armnn
+{
+
+RefMeanWorkload::RefMeanWorkload(const MeanQueueDescriptor& descriptor, const WorkloadInfo& info)
+  :BaseWorkload<MeanQueueDescriptor>(descriptor, info) {}
+
+void RefMeanWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefMeanWorkload_Execute");
+
+    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    auto inputDecoder  = MakeDecoder<float>(inputInfo,  m_Data.m_Inputs[0]->Map());
+    auto outputEncoder = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    Mean(inputInfo, outputInfo, m_Data.m_Parameters.m_Axis, *inputDecoder, *outputEncoder);
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefMeanWorkload.hpp b/src/backends/reference/workloads/RefMeanWorkload.hpp
new file mode 100644
index 0000000..eb4b407
--- /dev/null
+++ b/src/backends/reference/workloads/RefMeanWorkload.hpp
@@ -0,0 +1,24 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backendsCommon/Workload.hpp"
+#include "backendsCommon/WorkloadData.hpp"
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+
+namespace armnn
+{
+
+class RefMeanWorkload : public BaseWorkload<MeanQueueDescriptor>
+{
+public:
+    explicit RefMeanWorkload (const MeanQueueDescriptor& descriptor, const WorkloadInfo& info);
+    virtual void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 7cfced4..b141291 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -42,8 +42,7 @@
 #include "RefLstmWorkload.hpp"
 #include "RefConvertFp16ToFp32Workload.hpp"
 #include "RefConvertFp32ToFp16Workload.hpp"
-#include "RefMeanUint8Workload.hpp"
-#include "RefMeanFloat32Workload.hpp"
+#include "RefMeanWorkload.hpp"
 #include "RefPadWorkload.hpp"
 #include "RefBatchToSpaceNdUint8Workload.hpp"
 #include "RefBatchToSpaceNdFloat32Workload.hpp"