blob: 9b22033bd10e0735e88257ecb10c6c8f44c54c24 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NeonBatchMatMulWorkload.hpp"
7
8#include "NeonWorkloadUtils.hpp"
9
10#include <armnn/utility/PolymorphicDowncast.hpp>
Teresa Charlin1fe6c812022-11-01 15:59:50 +000011#include <aclCommon/ArmComputeUtils.hpp>
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010012
13#include <backendsCommon/WorkloadUtils.hpp>
14
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010015namespace armnn
16{
Teresa Charlin1fe6c812022-11-01 15:59:50 +000017arm_compute::Status NeonBatchMatMulValidate(const TensorInfo& inputInfoX,
18 const TensorInfo& inputInfoY,
19 const TensorInfo& outputInfo,
20 const BatchMatMulDescriptor& descriptor,
21 const bool isFastMathEnabled,
22 const ActivationDescriptor* activationDescriptor)
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010023{
24 if (descriptor.m_AdjointX || descriptor.m_AdjointY )
25 {
26 throw Exception("Support for adjoint not implemented.");
27 }
28 if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
29 {
30 throw Exception("Only supported the MatMul in the last 2 dimensions");
31 }
32
Teresa Charlin1fe6c812022-11-01 15:59:50 +000033 arm_compute::TensorInfo aclInputInfoX = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoX);
34 arm_compute::TensorInfo aclInputInfoY = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoY);
35 arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(outputInfo);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010036
Teresa Charlin1fe6c812022-11-01 15:59:50 +000037 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
38 aclInputInfoX.set_are_values_constant(false);
39 aclInputInfoY.set_are_values_constant(false);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010040
Teresa Charlin1fe6c812022-11-01 15:59:50 +000041 const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
42 activationDescriptor);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010043
Teresa Charlin1fe6c812022-11-01 15:59:50 +000044 arm_compute::MatMulInfo matMulInfo;
45 matMulInfo.adj_lhs(descriptor.m_TransposeX);
46 matMulInfo.adj_rhs(descriptor.m_TransposeY);
47 matMulInfo.fused_activation(activationInfo);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010048
Teresa Charlin1fe6c812022-11-01 15:59:50 +000049 arm_compute::CpuMatMulSettings settings;
50 settings.fast_math(isFastMathEnabled);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010051
Teresa Charlin1fe6c812022-11-01 15:59:50 +000052 return arm_compute::NEMatMul::validate(&aclInputInfoX, &aclInputInfoY, &aclOutputInfo, matMulInfo, settings);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010053}
54
Teresa Charlin1fe6c812022-11-01 15:59:50 +000055NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
56 const WorkloadInfo& info,
57 const bool isFastMathEnabled)
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010058 : NeonBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
59{
60 if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
61 {
62 throw Exception("Support for adjoint not implemented.");
63 }
Teresa Charlin1fe6c812022-11-01 15:59:50 +000064 if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW
65 || descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010066 {
67 throw Exception("Only supported the MatMul in the last 2 dimensions");
68 }
69
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010070 m_Data.ValidateInputsOutputs("NeonBatchMatMulWorkload", 2, 1);
71
72 arm_compute::ITensor& inputX = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
73 arm_compute::ITensor& inputY = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
Teresa Charlin1fe6c812022-11-01 15:59:50 +000074 arm_compute::ITensor& output = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010075
Teresa Charlin1fe6c812022-11-01 15:59:50 +000076 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
77 inputX.info()->set_are_values_constant(false);
78 inputY.info()->set_are_values_constant(false);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010079
Teresa Charlin1fe6c812022-11-01 15:59:50 +000080 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010081
Teresa Charlin1fe6c812022-11-01 15:59:50 +000082 arm_compute::MatMulInfo matMulInfo;
83 matMulInfo.adj_lhs(descriptor.m_Parameters.m_TransposeX);
84 matMulInfo.adj_rhs(descriptor.m_Parameters.m_TransposeY);
85 matMulInfo.fused_activation(activationInfo);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010086
Teresa Charlin1fe6c812022-11-01 15:59:50 +000087 arm_compute::CpuMatMulSettings settings;
88 settings.fast_math(isFastMathEnabled);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010089
Teresa Charlin1fe6c812022-11-01 15:59:50 +000090 m_MatMulLayer.configure(&inputX, &inputY, &output, matMulInfo, settings);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010091
Teresa Charlin1fe6c812022-11-01 15:59:50 +000092 // Report Profiling Details
93 WorkloadInfo detailsInfo;
94 detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
95 detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
96 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonBatchMatMulWorkload_Construct",
97 descriptor.m_Parameters,
98 detailsInfo,
99 GetGuid());
Teresa Charlin0f86ecf2022-10-13 15:47:08 +0100100}
101
102void NeonBatchMatMulWorkload::Execute() const
103{
104 ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonBatchMatMulWorkload_Execute", this->GetGuid());
Teresa Charlin1fe6c812022-11-01 15:59:50 +0000105 m_MatMulLayer.run();
Teresa Charlin0f86ecf2022-10-13 15:47:08 +0100106}
107} //namespace armnn