IVGCVSW-1831 - Add dimension check to MeanQueueDescriptor::Validate to check if the output dimension is correct from a given input and options.

Change-Id: Ibc15d9ea3151a7ba1935feafeb1843ee035e7f2e
diff --git a/src/armnn/backends/WorkloadData.cpp b/src/armnn/backends/WorkloadData.cpp
index c934a53..3ed77da 100644
--- a/src/armnn/backends/WorkloadData.cpp
+++ b/src/armnn/backends/WorkloadData.cpp
@@ -129,6 +129,18 @@
     }
 }
 
+void ValidateTensorMaxNumElements(const TensorInfo& tensor,
+                                  std::string const& descName,
+                                  unsigned int maxNumElements,
+                                  std::string const& tensorName)
+{
+    if (tensor.GetNumElements() > maxNumElements)
+    {
+        throw InvalidArgumentException(descName + ": Expected maximum of " + to_string(maxNumElements) + " but got " +
+            to_string(tensor.GetNumElements()) + " elements for " + tensorName + " tensor.");
+    }
+}
+
 //---------------------------------------------------------------
 void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
     const std::string& descName, std::string const& tensorName)
@@ -828,6 +840,29 @@
 {
     ValidateSingleInput(workloadInfo, "MeanQueueDescriptor");
     ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor");
+
+    const TensorInfo& input  = workloadInfo.m_InputTensorInfos[0];
+    const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
+
+    if (m_Keepdims)
+    {
+        ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output");
+    }
+    else if (m_Axis == nullptr)
+    {
+        ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output");
+    }
+    else
+    {
+        const TensorInfo& axis = m_Axis->GetTensorInfo();
+        ValidateTensorNumDimensions(axis, "MeanQueueDescriptor", 1, "axis");
+        ValidateTensorMaxNumElements(axis, "MeanQueueDescriptor", input.GetNumDimensions(), "axis");
+        unsigned int outputDim = input.GetNumDimensions() - axis.GetNumElements();
+        ValidateTensorNumDimensions(output,
+                                    "MeanQueueDescriptor",
+                                    outputDim > 0 ? outputDim : 1,
+                                    "output");
+    }
 }
 
 } //namespace armnn
diff --git a/src/armnn/backends/WorkloadData.hpp b/src/armnn/backends/WorkloadData.hpp
index f8f7e32..face761 100644
--- a/src/armnn/backends/WorkloadData.hpp
+++ b/src/armnn/backends/WorkloadData.hpp
@@ -199,6 +199,15 @@
 // Mean layer workload data.
 struct MeanQueueDescriptor : QueueDescriptor
 {
+    MeanQueueDescriptor()
+        : m_Axis(nullptr)
+        , m_Keepdims(false)
+    {
+    }
+
+    const ConstCpuTensorHandle* m_Axis;
+    bool m_Keepdims;
+
     void Validate(const WorkloadInfo& workloadInfo) const;
 };