blob: f21666b90a1a0e7492860d1de06bc35399288436 [file] [log] [blame]
Teresa Charlin94916a52022-10-19 08:48:07 +01001//
Teresa Charlinc074cdb2023-01-13 18:44:00 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Teresa Charlin94916a52022-10-19 08:48:07 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "ClBatchMatMulWorkload.hpp"
7
8#include "ClWorkloadUtils.hpp"
9
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11#include <aclCommon/ArmComputeUtils.hpp>
12
13#include <armnn/utility/PolymorphicDowncast.hpp>
14
15#include <armnnUtils/Permute.hpp>
Mike Kelly0e3fe102023-01-23 19:32:06 +000016#include <armnnUtils/TensorUtils.hpp>
Teresa Charlin94916a52022-10-19 08:48:07 +010017
18#include <backendsCommon/WorkloadUtils.hpp>
19
20#include <cl/ClTensorHandle.hpp>
21
22#include <arm_compute/runtime/CL/functions/CLGEMM.h>
23#include <arm_compute/runtime/CL/functions/CLPermute.h>
24
25
26namespace armnn
27{
Mike Kelly0e3fe102023-01-23 19:32:06 +000028
Teresa Charlin94916a52022-10-19 08:48:07 +010029arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
30 const TensorInfo& inputY,
31 const TensorInfo& output,
32 const BatchMatMulDescriptor& descriptor)
33{
34 if (descriptor.m_AdjointX || descriptor.m_AdjointY )
35 {
36 throw Exception("Support for adjoint not implemented.");
37 }
38 if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
39 {
40 throw Exception("Only supported the MatMul in the last 2 dimensions");
41 }
42
43 arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
44 arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
45 arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
46
Mike Kelly0e3fe102023-01-23 19:32:06 +000047 // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
48 // tensors so try to reduce the dimensions to 3
49 const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
50 const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
51 const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);
Teresa Charlin94916a52022-10-19 08:48:07 +010052
53 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
54 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
55
56 if (descriptor.m_TransposeX == true)
57 {
Mike Kelly0e3fe102023-01-23 19:32:06 +000058 armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);
59
60 auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
Teresa Charlin94916a52022-10-19 08:48:07 +010061 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
Mike Kelly0e3fe102023-01-23 19:32:06 +000062 const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
63 aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);
Teresa Charlin94916a52022-10-19 08:48:07 +010064
65 statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
66 &aclPermutedXInfo,
67 aclPermutationXVector);
68 }
69
Mike Kelly0e3fe102023-01-23 19:32:06 +000070 if (descriptor.m_TransposeY == true)
Teresa Charlin94916a52022-10-19 08:48:07 +010071 {
Mike Kelly0e3fe102023-01-23 19:32:06 +000072 armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);
73
74 auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
Teresa Charlin94916a52022-10-19 08:48:07 +010075 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
Mike Kelly0e3fe102023-01-23 19:32:06 +000076 const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
77 aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);
Teresa Charlin94916a52022-10-19 08:48:07 +010078
79 statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
80 &aclPermutedYInfo,
81 aclPermutationYVector);
Teresa Charlin94916a52022-10-19 08:48:07 +010082 }
83
84 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
85 false, // is inputY reshaped
86 false); // is inputY reshaped only 1st run
87
88
89 statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
90 descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
91 nullptr,
92 &aclOutputInfo,
93 1.0,
94 0,
95 gemm_info);
96
97 if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
98 statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
99 statusGEMM.error_code() == arm_compute::ErrorCode::OK)
100 {
101 return arm_compute::Status(arm_compute::ErrorCode::OK,
102 "All Batch Mat Mul layers validate status OK.");
103 }
104 else
105 {
106 return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
107 "BatchMatMul layer validate status failed."
108 + statusGEMM.error_description()
109 + statusPermuteX.error_description()
110 + statusPermuteY.error_description());
111 }
112
113}
114
115ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
116 const WorkloadInfo& info,
117 const arm_compute::CLCompileContext& clCompileContext)
118 : ClBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
119{
120 // Report Profiling Details
121 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
122 descriptor.m_Parameters,
123 info,
124 this->GetGuid());
125
126 if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
127 {
128 throw Exception("Support for adjoint not implemented.");
129 }
130 if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
131 descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
132 {
133 throw Exception("Only supported the MatMul in the last 2 dimensions");
134 }
135
136 m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
137
138 const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
139 const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
140 arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
141
142 inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
Mike Kelly0e3fe102023-01-23 19:32:06 +0000143 arm_compute::TensorShape inputXTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
144 info.m_InputTensorInfos[0].GetShape(), 3);
145 inputX.info()->set_tensor_shape(inputXTensorInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +0100146 inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
Mike Kelly0e3fe102023-01-23 19:32:06 +0000147 arm_compute::TensorShape inputYTensorInfo = armcomputetensorutils::BuildArmComputeTensorShape(
148 info.m_InputTensorInfos[1].GetShape(), 3);
149 inputY.info()->set_tensor_shape(inputYTensorInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +0100150
151 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
152 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
153
154 if (descriptor.m_Parameters.m_TransposeX == true)
155 {
Mike Kelly0e3fe102023-01-23 19:32:06 +0000156 armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);
157
Teresa Charlin94916a52022-10-19 08:48:07 +0100158 armnn::PermutationVector permutationXVector
Mike Kelly0e3fe102023-01-23 19:32:06 +0000159 = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
160 const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
Teresa Charlin94916a52022-10-19 08:48:07 +0100161 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
162 armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
163 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
164
165 auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
166 permuteLayerX->configure(clCompileContext,
167 &inputX,
168 &m_PermutedTensorX,
169 aclPermutationXVector);
170 m_PermuteLayerX.reset(permuteLayerX.release());
171 }
172
173 if (descriptor.m_Parameters.m_TransposeY == true)
174 {
Mike Kelly0e3fe102023-01-23 19:32:06 +0000175 armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);
176
Teresa Charlin94916a52022-10-19 08:48:07 +0100177 armnn::PermutationVector permutationYVector
Mike Kelly0e3fe102023-01-23 19:32:06 +0000178 = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
179 const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
Teresa Charlin94916a52022-10-19 08:48:07 +0100180 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
181 armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
182 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
183
Teresa Charlinc074cdb2023-01-13 18:44:00 +0000184 auto permuteLayerY = std::make_unique<arm_compute::CLPermute>();
Teresa Charlin94916a52022-10-19 08:48:07 +0100185 permuteLayerY->configure(clCompileContext,
186 &inputY,
187 &m_PermutedTensorY,
188 aclPermutationYVector);
189 m_PermuteLayerY.reset(permuteLayerY.release());
190 }
191
192 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
193 false, // is inputY reshaped
194 false); // is inputY reshaped only 1st run
195 auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
196 gemmLayer->configure(clCompileContext,
197 descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
198 descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
199 nullptr,
200 &output,
201 1.0,
202 0,
203 gemm_info);
204 m_GEMMLayer.reset(gemmLayer.release());
205}
206
207void ClBatchMatMulWorkload::Execute() const
208{
209 ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
210 if (m_PermuteLayerX)
211 {
212 m_PermuteLayerX->run();
213 }
214 if (m_PermuteLayerY)
215 {
216 m_PermuteLayerY->run();
217 }
218 m_GEMMLayer->run();
219}
220} //namespace armnn