//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "BatchMatMulLayer.hpp"

#include <armnn/backends/WorkloadFactory.hpp>
#include <armnnUtils/Permute.hpp>
#include "layers/LayerCloneBase.hpp"

namespace armnn
{

14BatchMatMulLayer::BatchMatMulLayer(const BatchMatMulDescriptor& param, const char* name)
15 : LayerWithParameters(2, 1, LayerType::BatchMatMul, param, name)
16{}
18std::unique_ptr<IWorkload> BatchMatMulLayer::CreateWorkload(const IWorkloadFactory& factory) const
19{
20 BatchMatMulQueueDescriptor descriptor;
21 SetAdditionalInfo(descriptor);
22
23 return factory.CreateWorkload(LayerType::BatchMatMul, descriptor, PrepInfoAndDesc(descriptor));
24}
26BatchMatMulLayer* BatchMatMulLayer::Clone(Graph& graph) const
27{
28 auto layer = CloneBase<BatchMatMulLayer>(graph, m_Param, GetName());
29
30 return std::move(layer);
31}
33std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
34{
35 ARMNN_ASSERT(inputShapes.size() == 2);
36
37 TensorShape inputXShape = inputShapes[0];
38 TensorShape inputYShape = inputShapes[1];
39
Samuel Yap75d2cb12022-08-08 14:07:42 +010040 // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
41 if(m_Param.m_TransposeX)
42 {
43 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
44 inputXShape);
45 inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
46 }
47 if(m_Param.m_TransposeY)
48 {
49 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
50 inputYShape);
51 inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
52 }
Samuel Yap4b7a34d2022-07-06 15:36:03 +010053
54 TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
Samuel Yap75d2cb12022-08-08 14:07:42 +010055 inputXShape : inputYShape;
Samuel Yap4b7a34d2022-07-06 15:36:03 +010056 TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
Samuel Yap75d2cb12022-08-08 14:07:42 +010057 inputYShape : inputXShape;
Samuel Yap4b7a34d2022-07-06 15:36:03 +010058
59 unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();
60
61 unsigned int outputNumDimensions = longerInput.GetNumDimensions();
62
63 std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
64
Samuel Yap75d2cb12022-08-08 14:07:42 +010065 const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
66 m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
67 auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout,
68 longerInput);
Samuel Yap4b7a34d2022-07-06 15:36:03 +010069
70 for (unsigned int i = 0; i < outputNumDimensions; ++i)
71 {
72 if (i == longerAxesToMul.first)
73 {
74 tensorDimensions[i] = &shorterInput == &inputXShape ? inputXShape[i - inputNumDimsOffset] : inputXShape[i];
75 }
76 else if(i == longerAxesToMul.second)
77 {
78 tensorDimensions[i] = &shorterInput == &inputYShape ? inputYShape[i - inputNumDimsOffset] : inputYShape[i];
79 }
80 else // The other dimensions not to be multiplied (but may be broadcasted)
81 {
82 // Does NOT validate whether it's a valid broadcast - that's done in the validate func in WorkloadData.cpp
83 tensorDimensions[i] = static_cast<int>(i) - static_cast<int>(inputNumDimsOffset) < 0 ?
84 longerInput[i] :
85 std::max(longerInput[i], shorterInput[i - inputNumDimsOffset]);
86 }
87 }
88
89 auto outputShape = TensorShape(outputNumDimensions, tensorDimensions.data());
90 return std::vector<TensorShape>({ outputShape });
91}
93void BatchMatMulLayer::ValidateTensorShapesFromInputs()
94{
95 VerifyLayerConnections(2, CHECK_LOCATION());
96
97 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
98
99 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
100
101 auto inferredShapes = InferOutputShapes({
102 GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
103 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
104
105 ARMNN_ASSERT(inferredShapes.size() == 1);
106
107 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "BatchMatMulLayer");
108}

} // namespace armnn