IVGCVSW-7109: Add Batch MatMul front end support - Reference

  * Descriptors added for BatchMatMul
  * Layer definition added (see the usage sketch after this list)
  * Input validation added (will likely change when optional parameter support comes in)
  * Ref workload implementation for BatchMatMul added (will also change with optional parameter support)
  * Ref layer tests made for BatchMatMul
  * CMake and other build files updated
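For illustration, a minimal sketch of how the new layer might be wired into a
network. This is a hypothetical usage example, not part of this patch; it assumes
the INetwork::AddBatchMatMulLayer entry point introduced alongside this front-end
work, and the identifiers are illustrative:

    // Hypothetical usage sketch (assumes AddBatchMatMulLayer as noted above)
    armnn::BatchMatMulDescriptor desc; // defaults: no transpose/adjoint, no explicit data layouts
    armnn::INetworkPtr net = armnn::INetwork::Create();
    armnn::IConnectableLayer* inputX = net->AddInputLayer(0);
    armnn::IConnectableLayer* inputY = net->AddInputLayer(1);
    armnn::IConnectableLayer* bmm    = net->AddBatchMatMulLayer(desc, "batchMatMul");
    armnn::IConnectableLayer* output = net->AddOutputLayer(0);
    inputX->GetOutputSlot(0).Connect(bmm->GetInputSlot(0));
    inputY->GetOutputSlot(0).Connect(bmm->GetInputSlot(1));
    bmm->GetOutputSlot(0).Connect(output->GetInputSlot(0));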

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 606821b..9a4c60f 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -4143,5 +4143,232 @@
     }
 }
 
+void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    const std::string descriptorName{"BatchMatMulQueueDescriptor"};
+
+    ValidateNumInputs(workloadInfo,  descriptorName, 2);
+    ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+    // Both inputs must be 2D or greater.
+    // For inputs X and Y whose axes to be multiplied are (M,N) and (I,J) respectively,
+    // axes N and I must be the same size.
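+    // e.g. X of shape [3,4] and Y of shape [4,5] are valid (N = I = 4)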
+
+    const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
+    const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
+    const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::BFloat16,
+        DataType::Float16,
+        DataType::Float32,
+        DataType::QAsymmS8,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
+    };
+
+    ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+    if ((inputTensorXInfo.GetNumDimensions() < 2) ||
+        (inputTensorYInfo.GetNumDimensions() < 2))
+    {
+        throw InvalidArgumentException(descriptorName + ": Input tensors must be at least 2D.");
+    }
+
+    if(m_Parameters.m_DataLayoutX.has_value())
+    {
+        switch(m_Parameters.m_DataLayoutX.value())
+        {
+            case DataLayout::NCHW:
+            case DataLayout::NHWC:
+                if(inputTensorXInfo.GetNumDimensions() != 4)
+                {
+                    throw InvalidArgumentException(descriptorName +
+                        ": Input tensor X does not have the correct "
+                        "number of dimensions for the Data Layout that it has been assigned.");
+                }
+                break;
+            case DataLayout::NCDHW:
+            case DataLayout::NDHWC:
+                if(inputTensorXInfo.GetNumDimensions() != 5)
+                {
+                    throw InvalidArgumentException(descriptorName +
+                        ": Input tensor X does not have the correct "
+                        "number of dimensions for the Data Layout that it has been assigned.");
+                }
+                break;
+            default:
+                break;
+        }
+    }
+
+    if(m_Parameters.m_DataLayoutY.has_value())
+    {
+        switch(m_Parameters.m_DataLayoutY.value())
+        {
+            case DataLayout::NCHW:
+            case DataLayout::NHWC:
+                if(inputTensorYInfo.GetNumDimensions() != 4)
+                {
+                    throw InvalidArgumentException(descriptorName +
+                        ": Input tensor Y does not have the correct "
+                        "number of dimensions for the Data Layout that it has been assigned.");
+                }
+                break;
+            case DataLayout::NCDHW:
+            case DataLayout::NDHWC:
+                if(inputTensorYInfo.GetNumDimensions() != 5)
+                {
+                    throw InvalidArgumentException(descriptorName +
+                        ": Input tensor Y does not have the correct "
+                        "number of dimensions for the Data Layout that it has been assigned.");
+                }
+                break;
+            default:
+                break;
+        }
+    }
+
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
+                                                         inputTensorXInfo.GetShape(),
+                                                         inputTensorYInfo.GetShape());
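+    // axesToMul holds, per input, the pair of axis indices involved in the
+    // multiplication: (M,N) for X and (I,J) for Y, as described above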
+
+    if(inputTensorXInfo.GetShape()[axesToMul.first.second]
+       != inputTensorYInfo.GetShape()[axesToMul.second.first])
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": The final axis of input tensor X must be the same size as "
+            "the second last axis of input tensor Y.");
+    }
+
+    auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
+                                                           inputTensorXInfo.GetShape(),
+                                                           inputTensorYInfo.GetShape());
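+    // axesNotMul holds the remaining (batch/broadcast) axis indices of each
+    // input; these are checked for broadcast compatibility further below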
+
+    {   // Separate scope so we don't pollute the rest of the scope with our temp variables
+        // e.g. NHWC isn't currently compatible with NCHW
+        DataLayout xLayout;
+        DataLayout yLayout;
+
+        if(m_Parameters.m_DataLayoutX == EmptyOptional())
+        {
+            xLayout = DataLayout::NCHW; // Not truly equivalent - only the last two axes matter here
+        }
+        else
+        {
+            xLayout = m_Parameters.m_DataLayoutX.value();
+        }
+
+        if(m_Parameters.m_DataLayoutY == EmptyOptional())
+        {
+            yLayout = DataLayout::NCHW;
+        }
+        else
+        {
+            yLayout = m_Parameters.m_DataLayoutY.value();
+        }
+
+        if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
+        {
+            if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
+            {
+                throw InvalidArgumentException(descriptorName +
+                    ": Invalid input tensor data layout combination.");
+            }
+        }
+        if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
+        {
+            if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
+            {
+                throw InvalidArgumentException(descriptorName +
+                    ": Invalid input tensor data layout combination.");
+            }
+        }
+    }
+
+    // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
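+    // e.g. for X: [5,2,3,4] and Y: [2,4,5], the non-multiplied axes are [5,2] and [2];
+    // the shorter set is treated as [1,2] for the broadcast compatibility check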
+    unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
+                                                inputTensorYInfo.GetNumDimensions());
+    if(outputTensorDimSize > 2) // i.e. there are batch dimensions beyond the two matrix axes
+    {
+        TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
+                                          DataType::Float32);
+        TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
+                                          DataType::Float32);
+        TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
+                                            DataType::Float32);
+
+        auto doAxisExtension = [&](const std::vector<unsigned int>& axisIndices,
+                                   const TensorInfo& inputInfo,
+                                   TensorInfo& ti)
+        {
+            unsigned int sizeDiff = static_cast<unsigned int>(
+                (outputTensorDimSize-2) - axisIndices.size());
+
+            // Left-pad the shorter tensor with broadcast dimensions of size 1
+            for(unsigned int i = 0; i < sizeDiff; i++)
+            {
+                ti.GetShape()[i] = 1;
+            }
+
+            // Then copy across the sizes of this input's non-multiplied axes
+            for(unsigned int i = 0; i < axisIndices.size(); i++)
+            {
+                ti.GetShape()[sizeDiff + i] = inputInfo.GetShape()[axisIndices[i]];
+            }
+        };
+
+        doAxisExtension(axesNotMul.first,  inputTensorXInfo, tiXNotMul);
+        doAxisExtension(axesNotMul.second, inputTensorYInfo, tiYNotMul);
+
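+        // The expected non-multiplied output dims are the element-wise max of the
+        // two aligned inputs (broadcast compatibility is validated just below)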
+        for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
+        {
+            tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
+                                                 tiYNotMul.GetShape()[i]);
+        }
+
+        ValidateBroadcastTensorShapesMatch(tiXNotMul,
+                                           tiYNotMul,
+                                           tiOutNotMul,
+                                           descriptorName,
+                                           "input_X",
+                                           "input_Y");
+    }
+
+    // Also check descriptor parameter validity
+    // This will eventually be moved to the start of the function (see the note at the end)
+    if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
+        (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameters - Transpose and Adjoint "
+            "vectors cannot both be true for a given input tensor.");
+    }
+
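+    // e.g. if set, a transpose/adjoint vector must have one entry per axis of its
+    // input (4 entries for a 4D X), and at most one of the two may be set per input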
+    if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameter - Transpose X vector must be "
+            "the same size as tensor input X's dimensionality.");
+    }
+    if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameter - Adjoint X vector must be "
+            "the same size as tensor input X's dimensionality.");
+    }
+    if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameter - Transpose Y vector must be "
+            "the same size as tensor input Y's dimensionality.");
+    }
+    if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorYInfo.GetNumDimensions())
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameter - Adjoint Y vector must be "
+            "the same size as tensor input Y's dimensionality.");
+    }
+    // Note: when adjoint/transpose is set, this validation will eventually need to be
+    // performed on the resultant (permuted) shapes rather than on the raw inputs.
+}
+
 
 } // namespace armnn
\ No newline at end of file