blob: bd0fd516179e70ee39fca97e254a887bdbc72ef7 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ClBatchMatMulWorkload.hpp"
7
8#include "ClWorkloadUtils.hpp"
9
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11#include <aclCommon/ArmComputeUtils.hpp>
12
13#include <armnn/utility/PolymorphicDowncast.hpp>
14
Teresa Charlin94916a52022-10-19 08:48:07 +010015#include <backendsCommon/WorkloadUtils.hpp>
16
17#include <cl/ClTensorHandle.hpp>
18
Teresa Charlin94916a52022-10-19 08:48:07 +010019
20namespace armnn
21{
Mike Kelly0e3fe102023-01-23 19:32:06 +000022
Teresa Charlin97a3aef2023-01-10 10:32:51 +000023arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputInfoX,
24 const TensorInfo& inputInfoY,
25 const TensorInfo& outputInfo,
26 const BatchMatMulDescriptor& descriptor,
27 const ActivationDescriptor* activationDescriptor)
Teresa Charlin94916a52022-10-19 08:48:07 +010028{
29 if (descriptor.m_AdjointX || descriptor.m_AdjointY )
30 {
31 throw Exception("Support for adjoint not implemented.");
32 }
33 if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
34 {
35 throw Exception("Only supported the MatMul in the last 2 dimensions");
36 }
37
Teresa Charlin97a3aef2023-01-10 10:32:51 +000038 arm_compute::TensorInfo aclInputInfoX = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoX);
39 arm_compute::TensorInfo aclInputInfoY = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoY);
40 const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(outputInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +010041
Teresa Charlin97a3aef2023-01-10 10:32:51 +000042 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
43 aclInputInfoX.set_are_values_constant(false);
44 aclInputInfoY.set_are_values_constant(false);
Teresa Charlin94916a52022-10-19 08:48:07 +010045
Teresa Charlin97a3aef2023-01-10 10:32:51 +000046 const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
47 activationDescriptor);
Teresa Charlin94916a52022-10-19 08:48:07 +010048
Teresa Charlin97a3aef2023-01-10 10:32:51 +000049 arm_compute::MatMulInfo matMulInfo;
50 matMulInfo.adj_lhs(descriptor.m_TransposeX);
51 matMulInfo.adj_rhs(descriptor.m_TransposeY);
52 matMulInfo.fused_activation(activationInfo);
Mike Kelly0e3fe102023-01-23 19:32:06 +000053
Teresa Charlin97a3aef2023-01-10 10:32:51 +000054 return arm_compute::CLMatMul::validate(&aclInputInfoX, &aclInputInfoY, &aclOutputInfo, matMulInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +010055}
56
57ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
58 const WorkloadInfo& info,
59 const arm_compute::CLCompileContext& clCompileContext)
60 : ClBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
61{
62 // Report Profiling Details
63 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
64 descriptor.m_Parameters,
65 info,
66 this->GetGuid());
67
68 if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
69 {
70 throw Exception("Support for adjoint not implemented.");
71 }
72 if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
73 descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
74 {
75 throw Exception("Only supported the MatMul in the last 2 dimensions");
76 }
77
78 m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
79
Teresa Charlin97a3aef2023-01-10 10:32:51 +000080 arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
81 arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
82 auto outputHandle = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0]);
83 arm_compute::ICLTensor& output = outputHandle->GetTensor();
Teresa Charlin94916a52022-10-19 08:48:07 +010084
Teresa Charlin97a3aef2023-01-10 10:32:51 +000085 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
86 inputX.info()->set_are_values_constant(false);
87 inputY.info()->set_are_values_constant(false);
Teresa Charlin94916a52022-10-19 08:48:07 +010088
Teresa Charlin97a3aef2023-01-10 10:32:51 +000089 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
Teresa Charlin94916a52022-10-19 08:48:07 +010090
Teresa Charlin97a3aef2023-01-10 10:32:51 +000091 arm_compute::MatMulInfo matMulInfo;
92 matMulInfo.adj_lhs(descriptor.m_Parameters.m_TransposeX);
93 matMulInfo.adj_rhs(descriptor.m_Parameters.m_TransposeY);
94 matMulInfo.fused_activation(activationInfo);
Mike Kelly0e3fe102023-01-23 19:32:06 +000095
Teresa Charlin97a3aef2023-01-10 10:32:51 +000096 m_MatMulLayer.configure(clCompileContext, &inputX, &inputY, &output, matMulInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +010097
Teresa Charlin97a3aef2023-01-10 10:32:51 +000098 // Report Profiling Details
99 WorkloadInfo detailsInfo;
100 detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
101 detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
102 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
103 descriptor.m_Parameters,
104 detailsInfo,
105 GetGuid());
Teresa Charlin94916a52022-10-19 08:48:07 +0100106}
107
108void ClBatchMatMulWorkload::Execute() const
109{
110 ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
Teresa Charlin97a3aef2023-01-10 10:32:51 +0000111 RunClFunction(m_MatMulLayer, CHECK_LOCATION());
Teresa Charlin94916a52022-10-19 08:48:07 +0100112}
113} //namespace armnn