blob: 5dd542e004e9df4ffd0b4c66c158f1d0ac24d9f7 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ClBatchMatMulWorkload.hpp"
7
8#include "ClWorkloadUtils.hpp"
9
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11#include <aclCommon/ArmComputeUtils.hpp>
12
13#include <armnn/utility/PolymorphicDowncast.hpp>
14
Teresa Charlin94916a52022-10-19 08:48:07 +010015#include <backendsCommon/WorkloadUtils.hpp>
16
17#include <cl/ClTensorHandle.hpp>
18
Nikhil Raj038f52b2023-07-31 10:06:32 +010019#include <arm_compute/function_info/MatMulInfo.h>
Teresa Charlin94916a52022-10-19 08:48:07 +010020
21namespace armnn
22{
Mike Kelly0e3fe102023-01-23 19:32:06 +000023
Teresa Charlin97a3aef2023-01-10 10:32:51 +000024arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputInfoX,
25 const TensorInfo& inputInfoY,
26 const TensorInfo& outputInfo,
27 const BatchMatMulDescriptor& descriptor,
28 const ActivationDescriptor* activationDescriptor)
Teresa Charlin94916a52022-10-19 08:48:07 +010029{
30 if (descriptor.m_AdjointX || descriptor.m_AdjointY )
31 {
32 throw Exception("Support for adjoint not implemented.");
33 }
34 if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
35 {
36 throw Exception("Only supported the MatMul in the last 2 dimensions");
37 }
38
Teresa Charlin97a3aef2023-01-10 10:32:51 +000039 arm_compute::TensorInfo aclInputInfoX = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoX);
40 arm_compute::TensorInfo aclInputInfoY = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoY);
41 const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(outputInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +010042
Teresa Charlin97a3aef2023-01-10 10:32:51 +000043 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
44 aclInputInfoX.set_are_values_constant(false);
45 aclInputInfoY.set_are_values_constant(false);
Teresa Charlin94916a52022-10-19 08:48:07 +010046
Teresa Charlin97a3aef2023-01-10 10:32:51 +000047 const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
48 activationDescriptor);
Teresa Charlin94916a52022-10-19 08:48:07 +010049
Teresa Charlin97a3aef2023-01-10 10:32:51 +000050 arm_compute::MatMulInfo matMulInfo;
51 matMulInfo.adj_lhs(descriptor.m_TransposeX);
52 matMulInfo.adj_rhs(descriptor.m_TransposeY);
Mike Kelly0e3fe102023-01-23 19:32:06 +000053
Nikhil Rajd29d09d2023-06-26 11:52:40 +010054 return arm_compute::CLMatMul::validate(&aclInputInfoX, &aclInputInfoY, &aclOutputInfo, matMulInfo, activationInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +010055}
56
57ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
58 const WorkloadInfo& info,
59 const arm_compute::CLCompileContext& clCompileContext)
60 : ClBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
61{
62 // Report Profiling Details
63 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
64 descriptor.m_Parameters,
65 info,
66 this->GetGuid());
67
68 if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
69 {
70 throw Exception("Support for adjoint not implemented.");
71 }
72 if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
73 descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
74 {
75 throw Exception("Only supported the MatMul in the last 2 dimensions");
76 }
77
78 m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
79
Teresa Charlin97a3aef2023-01-10 10:32:51 +000080 arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
81 arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
82 auto outputHandle = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0]);
83 arm_compute::ICLTensor& output = outputHandle->GetTensor();
Teresa Charlin94916a52022-10-19 08:48:07 +010084
Teresa Charlin97a3aef2023-01-10 10:32:51 +000085 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
86 inputX.info()->set_are_values_constant(false);
87 inputY.info()->set_are_values_constant(false);
Teresa Charlin94916a52022-10-19 08:48:07 +010088
Teresa Charlin97a3aef2023-01-10 10:32:51 +000089 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
Teresa Charlin94916a52022-10-19 08:48:07 +010090
Teresa Charlin97a3aef2023-01-10 10:32:51 +000091 arm_compute::MatMulInfo matMulInfo;
92 matMulInfo.adj_lhs(descriptor.m_Parameters.m_TransposeX);
93 matMulInfo.adj_rhs(descriptor.m_Parameters.m_TransposeY);
Mike Kelly0e3fe102023-01-23 19:32:06 +000094
Nikhil Rajd29d09d2023-06-26 11:52:40 +010095 arm_compute::GpuMatMulSettings settings;
96
97 m_MatMulLayer.configure(clCompileContext, &inputX, &inputY, &output, matMulInfo, settings, activationInfo);
Teresa Charlin94916a52022-10-19 08:48:07 +010098
Teresa Charlin97a3aef2023-01-10 10:32:51 +000099 // Report Profiling Details
100 WorkloadInfo detailsInfo;
101 detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
102 detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
103 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
104 descriptor.m_Parameters,
105 detailsInfo,
106 GetGuid());
Teresa Charlin94916a52022-10-19 08:48:07 +0100107}
108
109void ClBatchMatMulWorkload::Execute() const
110{
Mike Kelly7cbe7812023-07-25 17:37:33 +0100111 ARMNN_SCOPED_PROFILING_EVENT_CL_NAME_GUID("ClBatchMatMulWorkload_Execute");
Teresa Charlin97a3aef2023-01-10 10:32:51 +0000112 RunClFunction(m_MatMulLayer, CHECK_LOCATION());
Teresa Charlin94916a52022-10-19 08:48:07 +0100113}
114} //namespace armnn