blob: 3d8651f995df1c8ba7a43d8d48e14c20e84fc127 [file] [log] [blame]
Teresa Charlin0f86ecf2022-10-13 15:47:08 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "NeonBatchMatMulWorkload.hpp"
7
8#include "NeonWorkloadUtils.hpp"
9
10#include <armnn/utility/PolymorphicDowncast.hpp>
11
12#include <armnnUtils/Permute.hpp>
13
14#include <backendsCommon/WorkloadUtils.hpp>
15
16#include <arm_compute/runtime/NEON/functions/NEGEMM.h>
17
18#include <arm_compute/runtime/NEON/functions/NEPermute.h>
19
20
21namespace armnn
22{
23arm_compute::Status NeonBatchMatMulValidate(const TensorInfo& inputX,
24 const TensorInfo& inputY,
25 const TensorInfo& output,
26 const BatchMatMulDescriptor& descriptor)
27{
28 if (descriptor.m_AdjointX || descriptor.m_AdjointY )
29 {
30 throw Exception("Support for adjoint not implemented.");
31 }
32 if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
33 {
34 throw Exception("Only supported the MatMul in the last 2 dimensions");
35 }
36
37 const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
38 const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
39 const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
40
41 arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
42 arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
43 arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
44
45 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
46 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
47
48 if (descriptor.m_TransposeX == true)
49 {
50 auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
51 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
52 const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
53 aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
54
55 statusPermuteX = arm_compute::NEPermute::validate(&aclInputXInfo,
56 &aclPermutedXInfo,
57 aclPermutationXVector);
58 }
59
60 if (descriptor.m_TransposeY == true)
61 {
62 auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
63 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
64 const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
65 aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
66
67 statusPermuteY = arm_compute::NEPermute::validate(&aclInputYInfo,
68 &aclPermutedYInfo,
69 aclPermutationYVector);
70 }
71
72 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
73 false, // is inputY reshaped
74 false); // is inputY reshaped only 1st run
75
76 statusGEMM = arm_compute::NEGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
77 descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
78 nullptr,
79 &aclOutputInfo,
80 1.0,
81 0,
82 gemm_info);
83
84 if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
85 statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
86 statusGEMM.error_code() == arm_compute::ErrorCode::OK)
87 {
88 return arm_compute::Status(arm_compute::ErrorCode::OK,
89 "All BatchMatMul layers validate status OK.");
90 }
91 else
92 {
93 return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
94 "BatchMatMul layer validate status failed."
95 + statusGEMM.error_description()
96 + statusPermuteX.error_description()
97 + statusPermuteY.error_description());
98 }
99
100}
101
102NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(
103 const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info)
104 : NeonBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
105{
106 if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
107 {
108 throw Exception("Support for adjoint not implemented.");
109 }
110 if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
111 descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
112 {
113 throw Exception("Only supported the MatMul in the last 2 dimensions");
114 }
115
116 // Report Profiling Details
117 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonBatchMatMulWorkload_Construct",
118 descriptor.m_Parameters,
119 info,
120 this->GetGuid());
121
122 m_Data.ValidateInputsOutputs("NeonBatchMatMulWorkload", 2, 1);
123
124 arm_compute::ITensor& inputX = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
125 arm_compute::ITensor& inputY = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
126 auto outputHandle = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0]);
127 arm_compute::ITensor& output = outputHandle->GetTensor();
128
129 arm_compute::DataLayout aclDataLayoutX = ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX);
130 arm_compute::DataLayout aclDataLayoutY = ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY);
131
132 inputX.info()->set_data_layout(aclDataLayoutX);
133 inputY.info()->set_data_layout(aclDataLayoutY);
134
135 if (descriptor.m_Parameters.m_TransposeX == true)
136 {
137 armnn::PermutationVector permutationXVector
138 = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
139 const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
140 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
141
142 auto permuteLayerX = std::make_unique<arm_compute::NEPermute>();
143 BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
144 InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
145 permuteLayerX->configure(&inputX, &m_PermutedTensorX, aclPermutationXVector);
146 m_PermuteLayerX.reset(permuteLayerX.release());
147 }
148
149 if (descriptor.m_Parameters.m_TransposeY == true)
150 {
151 armnn::PermutationVector permutationYVector
152 = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions());
153 const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector);
154 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
155
156 auto permuteLayerY = std::make_unique<arm_compute::NEPermute>();
157 BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
158 InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
159 permuteLayerY->configure(&inputY, &m_PermutedTensorY, aclPermutationYVector);
160 m_PermuteLayerY.reset(permuteLayerY.release());
161 }
162
163 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
164 false, // is inputY reshaped
165 false); // is inputY reshaped only 1st run
166 auto gemmLayer = std::make_unique<arm_compute::NEGEMM>();
167 gemmLayer->configure(descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
168 descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
169 nullptr,
170 &output,
171 1.0,
172 0,
173 gemm_info);
174 m_GEMMLayer.reset(gemmLayer.release());
175}
176
177void NeonBatchMatMulWorkload::Execute() const
178{
179 ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonBatchMatMulWorkload_Execute", this->GetGuid());
180 if (m_PermuteLayerX)
181 {
182 m_PermuteLayerX->run();
183 }
184 if (m_PermuteLayerY)
185 {
186 m_PermuteLayerY->run();
187 }
188 m_GEMMLayer->run();
189}
190} //namespace armnn