blob: 4acdef5e5c81078df7890a87398ed9518093175e [file] [log] [blame]
Teresa Charlin94916a52022-10-19 08:48:07 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ClBatchMatMulWorkload.hpp"
7
8#include "ClWorkloadUtils.hpp"
9
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11#include <aclCommon/ArmComputeUtils.hpp>
12
13#include <armnn/utility/PolymorphicDowncast.hpp>
14
15#include <armnnUtils/Permute.hpp>
16
17#include <backendsCommon/WorkloadUtils.hpp>
18
19#include <cl/ClTensorHandle.hpp>
20
21#include <arm_compute/runtime/CL/functions/CLGEMM.h>
22#include <arm_compute/runtime/CL/functions/CLPermute.h>
23
24
25namespace armnn
26{
27arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
28 const TensorInfo& inputY,
29 const TensorInfo& output,
30 const BatchMatMulDescriptor& descriptor)
31{
32 if (descriptor.m_AdjointX || descriptor.m_AdjointY )
33 {
34 throw Exception("Support for adjoint not implemented.");
35 }
36 if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
37 {
38 throw Exception("Only supported the MatMul in the last 2 dimensions");
39 }
40
41 arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
42 arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
43 arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
44
45 const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
46 const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
47 const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
48
49 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
50 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
51
52 if (descriptor.m_TransposeX == true)
53 {
54 auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
55 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
56 const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
57 aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
58
59 statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
60 &aclPermutedXInfo,
61 aclPermutationXVector);
62 }
63
64 if ( descriptor.m_TransposeY == true)
65 {
66 auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
67 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
68 const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
69 aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
70
71 statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
72 &aclPermutedYInfo,
73 aclPermutationYVector);
74
75 }
76
77 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
78 false, // is inputY reshaped
79 false); // is inputY reshaped only 1st run
80
81
82 statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
83 descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
84 nullptr,
85 &aclOutputInfo,
86 1.0,
87 0,
88 gemm_info);
89
90 if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
91 statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
92 statusGEMM.error_code() == arm_compute::ErrorCode::OK)
93 {
94 return arm_compute::Status(arm_compute::ErrorCode::OK,
95 "All Batch Mat Mul layers validate status OK.");
96 }
97 else
98 {
99 return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
100 "BatchMatMul layer validate status failed."
101 + statusGEMM.error_description()
102 + statusPermuteX.error_description()
103 + statusPermuteY.error_description());
104 }
105
106}
107
108ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
109 const WorkloadInfo& info,
110 const arm_compute::CLCompileContext& clCompileContext)
111 : ClBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
112{
113 // Report Profiling Details
114 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
115 descriptor.m_Parameters,
116 info,
117 this->GetGuid());
118
119 if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
120 {
121 throw Exception("Support for adjoint not implemented.");
122 }
123 if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
124 descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
125 {
126 throw Exception("Only supported the MatMul in the last 2 dimensions");
127 }
128
129 m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
130
131 const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
132 const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
133 arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
134
135 inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
136 inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
137
138 arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
139 arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
140
141 if (descriptor.m_Parameters.m_TransposeX == true)
142 {
143 armnn::PermutationVector permutationXVector
144 = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
145 const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
146 const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
147 armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
148 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
149
150 auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
151 permuteLayerX->configure(clCompileContext,
152 &inputX,
153 &m_PermutedTensorX,
154 aclPermutationXVector);
155 m_PermuteLayerX.reset(permuteLayerX.release());
156 }
157
158 if (descriptor.m_Parameters.m_TransposeY == true)
159 {
160 armnn::PermutationVector permutationYVector
161 = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
162 const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationYVector);
163 const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
164 armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
165 armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
166
167 std::unique_ptr<arm_compute::CLPermute> permuteLayerY(new arm_compute::CLPermute());
168 permuteLayerY->configure(clCompileContext,
169 &inputY,
170 &m_PermutedTensorY,
171 aclPermutationYVector);
172 m_PermuteLayerY.reset(permuteLayerY.release());
173 }
174
175 const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
176 false, // is inputY reshaped
177 false); // is inputY reshaped only 1st run
178 auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
179 gemmLayer->configure(clCompileContext,
180 descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
181 descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
182 nullptr,
183 &output,
184 1.0,
185 0,
186 gemm_info);
187 m_GEMMLayer.reset(gemmLayer.release());
188}
189
190void ClBatchMatMulWorkload::Execute() const
191{
192 ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
193 if (m_PermuteLayerX)
194 {
195 m_PermuteLayerX->run();
196 }
197 if (m_PermuteLayerY)
198 {
199 m_PermuteLayerY->run();
200 }
201 m_GEMMLayer->run();
202}
203} //namespace armnn