blob: 628e314046d40fdb16ac1e0029a26a1c132a5fe2 [file] [log] [blame]
Teresa Charlin0f86ecf2022-10-13 15:47:08 +01001//
Teresa Charlin1fe6c812022-11-01 15:59:50 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Teresa Charlin0f86ecf2022-10-13 15:47:08 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "NeonBatchMatMulWorkload.hpp"
7
8#include "NeonWorkloadUtils.hpp"
9
10#include <armnn/utility/PolymorphicDowncast.hpp>
Teresa Charlin1fe6c812022-11-01 15:59:50 +000011#include <aclCommon/ArmComputeUtils.hpp>
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010012
13#include <backendsCommon/WorkloadUtils.hpp>
14
Nikhil Raja9c5c162023-06-16 15:54:32 +010015#include <arm_compute/core/MatMulInfo.h>
16
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010017namespace armnn
18{
Teresa Charlin1fe6c812022-11-01 15:59:50 +000019arm_compute::Status NeonBatchMatMulValidate(const TensorInfo& inputInfoX,
20 const TensorInfo& inputInfoY,
21 const TensorInfo& outputInfo,
22 const BatchMatMulDescriptor& descriptor,
23 const bool isFastMathEnabled,
24 const ActivationDescriptor* activationDescriptor)
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010025{
26 if (descriptor.m_AdjointX || descriptor.m_AdjointY )
27 {
28 throw Exception("Support for adjoint not implemented.");
29 }
30 if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
31 {
32 throw Exception("Only supported the MatMul in the last 2 dimensions");
33 }
34
Teresa Charlin1fe6c812022-11-01 15:59:50 +000035 arm_compute::TensorInfo aclInputInfoX = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoX);
36 arm_compute::TensorInfo aclInputInfoY = armcomputetensorutils::BuildArmComputeTensorInfo(inputInfoY);
37 arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(outputInfo);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010038
Teresa Charlin1fe6c812022-11-01 15:59:50 +000039 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
40 aclInputInfoX.set_are_values_constant(false);
41 aclInputInfoY.set_are_values_constant(false);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010042
Teresa Charlin1fe6c812022-11-01 15:59:50 +000043 const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
44 activationDescriptor);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010045
Teresa Charlin1fe6c812022-11-01 15:59:50 +000046 arm_compute::MatMulInfo matMulInfo;
47 matMulInfo.adj_lhs(descriptor.m_TransposeX);
48 matMulInfo.adj_rhs(descriptor.m_TransposeY);
49 matMulInfo.fused_activation(activationInfo);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010050
Teresa Charlin1fe6c812022-11-01 15:59:50 +000051 arm_compute::CpuMatMulSettings settings;
52 settings.fast_math(isFastMathEnabled);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010053
Teresa Charlin1fe6c812022-11-01 15:59:50 +000054 return arm_compute::NEMatMul::validate(&aclInputInfoX, &aclInputInfoY, &aclOutputInfo, matMulInfo, settings);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010055}
56
Teresa Charlin1fe6c812022-11-01 15:59:50 +000057NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
58 const WorkloadInfo& info,
59 const bool isFastMathEnabled)
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010060 : NeonBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
61{
62 if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
63 {
64 throw Exception("Support for adjoint not implemented.");
65 }
Teresa Charlin1fe6c812022-11-01 15:59:50 +000066 if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW
67 || descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010068 {
69 throw Exception("Only supported the MatMul in the last 2 dimensions");
70 }
71
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010072 m_Data.ValidateInputsOutputs("NeonBatchMatMulWorkload", 2, 1);
73
74 arm_compute::ITensor& inputX = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
75 arm_compute::ITensor& inputY = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
Teresa Charlin1fe6c812022-11-01 15:59:50 +000076 arm_compute::ITensor& output = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010077
Teresa Charlin1fe6c812022-11-01 15:59:50 +000078 // GeMM dispatches kernel handles dynamic inputs differently to static so this flag needs to be set
79 inputX.info()->set_are_values_constant(false);
80 inputY.info()->set_are_values_constant(false);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010081
Teresa Charlin1fe6c812022-11-01 15:59:50 +000082 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010083
Teresa Charlin1fe6c812022-11-01 15:59:50 +000084 arm_compute::MatMulInfo matMulInfo;
85 matMulInfo.adj_lhs(descriptor.m_Parameters.m_TransposeX);
86 matMulInfo.adj_rhs(descriptor.m_Parameters.m_TransposeY);
87 matMulInfo.fused_activation(activationInfo);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010088
Teresa Charlin1fe6c812022-11-01 15:59:50 +000089 arm_compute::CpuMatMulSettings settings;
90 settings.fast_math(isFastMathEnabled);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010091
Teresa Charlin1fe6c812022-11-01 15:59:50 +000092 m_MatMulLayer.configure(&inputX, &inputY, &output, matMulInfo, settings);
Teresa Charlin0f86ecf2022-10-13 15:47:08 +010093
Teresa Charlin1fe6c812022-11-01 15:59:50 +000094 // Report Profiling Details
95 WorkloadInfo detailsInfo;
96 detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
97 detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
98 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonBatchMatMulWorkload_Construct",
99 descriptor.m_Parameters,
100 detailsInfo,
101 GetGuid());
Teresa Charlin0f86ecf2022-10-13 15:47:08 +0100102}
103
104void NeonBatchMatMulWorkload::Execute() const
105{
106 ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonBatchMatMulWorkload_Execute", this->GetGuid());
Teresa Charlin1fe6c812022-11-01 15:59:50 +0000107 m_MatMulLayer.run();
Teresa Charlin0f86ecf2022-10-13 15:47:08 +0100108}
109} //namespace armnn