//
// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NeonGatherNdWorkload.hpp"
#include "NeonWorkloadUtils.hpp"
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <aclCommon/ArmComputeUtils.hpp>
#include "backendsCommon/WorkloadUtils.hpp"

namespace armnn
{
arm_compute::Status NeonGatherNdWorkloadValidate(const TensorInfo& paramsInfo,
                                                 const TensorInfo& indicesInfo,
                                                 const TensorInfo& outputInfo)
{
    // Calculate ND, K, W, C.
    std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(paramsInfo, indicesInfo);

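    // The decomposition below views indices as { W, ND } (W index tuples of ND coordinates each)
    // and params as { K, C } (K slices of C elements each). Mul and ReduceSum flatten each index
    // tuple into a single linear offset, Gather selects the W rows, and Reshape restores the
    // GatherNd output shape.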
    /// Validate Mul
    // Indices with shape { W, ND }
    armnn::TensorInfo indices_W_ND_Info = indicesInfo;
    indices_W_ND_Info.SetShape({ keyIndices["W"], keyIndices["ND"] });
    const arm_compute::TensorInfo aclIndicesInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);

    // Flattened coefficients with shape { ND }
    armnn::TensorInfo flattenedCoeff_Info = indicesInfo;
    flattenedCoeff_Info.SetShape({ keyIndices["ND"] });
    const arm_compute::TensorInfo aclFlattenedCoeffInfo = BuildArmComputeTensorInfo(flattenedCoeff_Info);

    // Output of Mul with shape { W, ND }
    const arm_compute::TensorInfo aclOutputMulInfo = BuildArmComputeTensorInfo(indices_W_ND_Info);

    auto statusMul = arm_compute::NEPixelWiseMultiplication::validate(&aclIndicesInfo,
                                                                      &aclFlattenedCoeffInfo,
                                                                      &aclOutputMulInfo,
                                                                      1.0f,
                                                                      arm_compute::ConvertPolicy::WRAP,
                                                                      arm_compute::RoundingPolicy::TO_ZERO,
                                                                      arm_compute::ActivationLayerInfo());

    /// Validate ReduceSum
    // Flattened indices with shape { W }
    armnn::TensorInfo flattenedIndices_Info = indicesInfo;
    flattenedIndices_Info.SetShape({ keyIndices["W"] });
    const arm_compute::TensorInfo aclFlattenedIndicesInfo = BuildArmComputeTensorInfo(flattenedIndices_Info);

    const std::vector<unsigned int> armnnReduceAxes(1, 1);
    arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(aclOutputMulInfo.num_dimensions(),
                                                                          indices_W_ND_Info.GetNumDimensions(),
                                                                          armnnReduceAxes);

    auto statusReduceSum = arm_compute::NEReductionOperation::validate(&aclOutputMulInfo,
                                                                       &aclFlattenedIndicesInfo,
                                                                       static_cast<unsigned int>(coords[0]),
                                                                       arm_compute::ReductionOperation::SUM,
                                                                       false);

    /// Validate Gather
    // Params with shape { K, C }
    armnn::TensorInfo params_K_C_Info = paramsInfo;
    params_K_C_Info.SetShape({ keyIndices["K"], keyIndices["C"] });
    const arm_compute::TensorInfo aclParamsInfo = BuildArmComputeTensorInfo(params_K_C_Info);

    // Output of gather with shape { W, C }
    armnn::TensorInfo outputGather_Info = outputInfo;
    outputGather_Info.SetShape({ keyIndices["W"], keyIndices["C"] });
    const arm_compute::TensorInfo aclOutputGatherInfo = BuildArmComputeTensorInfo(outputGather_Info);

    auto aclAxis = ComputeAclAxis(0, params_K_C_Info);
    auto statusGather =
        arm_compute::NEGather::validate(&aclParamsInfo, &aclFlattenedIndicesInfo, &aclOutputGatherInfo, aclAxis);

    /// Validate Reshape
    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(outputInfo);

    auto statusReshape = arm_compute::NEReshapeLayer::validate(&aclOutputGatherInfo, &aclOutputInfo);

    /// Return OK if all the layers are valid
    auto okCode = arm_compute::ErrorCode::OK;
    if (statusMul.error_code() == okCode &&
        statusReduceSum.error_code() == okCode &&
        statusGather.error_code() == okCode &&
        statusReshape.error_code() == okCode)
    {
        return arm_compute::Status(arm_compute::ErrorCode::OK,
                                   "All GatherND layers validate status OK.");
    }
    else
    {
        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
                                   "GatherND layer validate status failed.");
    }
}

NeonGatherNdWorkload::NeonGatherNdWorkload(const GatherNdQueueDescriptor& descriptor,
                                           const WorkloadInfo& info)
    : NeonBaseWorkload<GatherNdQueueDescriptor>(descriptor, info)
{
    m_Data.ValidateInputsOutputs("NeonGatherNdWorkload", 2, 1);

    TensorInfo paramsInfo = info.m_InputTensorInfos[0];
    TensorInfo indicesInfo = info.m_InputTensorInfos[1];
    TensorInfo outputInfo = info.m_OutputTensorInfos[0];

    arm_compute::ITensor& input = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    arm_compute::ITensor& indices = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    arm_compute::ITensor& output = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();

    // Calculate ND, K, W, C.
    std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(paramsInfo, indicesInfo);

    /// Calculate flattened indices: m_FlattenedIndices = indices * m_FlattenedCoeff.
    /// This could be done using MatMul instead of multiplication followed by reduce sum operation,
    /// but GeMM does not support s32 at the moment.

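    // In other words, each { W, ND } row of indices is multiplied element-wise by the ND
    // coefficients and then summed, i.e. dot(indices[w], flattenedCoeff), which gives the linear
    // offset of row w in the { K, C } view of params.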
    // Prepare the tensor to store the output of the reduce_sum operation
    armnn::TensorInfo flattenedIndices_Info = indicesInfo;
    flattenedIndices_Info.SetShape({ keyIndices["W"] });
    BuildArmComputeTensor(m_FlattenedIndices, flattenedIndices_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedIndices);

    // Reshape indices into { W, ND }
    indices.info()->set_tensor_shape(BuildArmComputeTensorShape({ keyIndices["W"], keyIndices["ND"] }));
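    // set_tensor_shape() only rewrites the tensor metadata: W * ND equals the original number of
    // index elements, so the existing indices buffer is reinterpreted rather than copied.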

    // Calculate the m_FlattenedCoeff
    TensorShape paramsShape = paramsInfo.GetShape();
    std::vector<int32_t> flattenedCoeff(keyIndices["ND"], 1);
    for (unsigned int i = 1; i < keyIndices["ND"]; ++i)
    {
        flattenedCoeff[i - 1] = static_cast<int32_t>(paramsShape[i]);
    }
    for (unsigned int i = keyIndices["ND"] - 1; i > 0; --i)
    {
        flattenedCoeff[i - 1] *= flattenedCoeff[i];
    }
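    // After the two loops flattenedCoeff holds the row-major strides of the first ND dimensions of
    // params, with the last entry equal to 1. For example, params of shape { 5, 4, 3 } with ND = 3
    // gives { 12, 3, 1 }, so an index tuple (i0, i1, i2) flattens to i0*12 + i1*3 + i2.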
    armnn::TensorInfo flattenedCoeff_Info = indicesInfo;
    flattenedCoeff_Info.SetShape({ keyIndices["ND"] });
    BuildArmComputeTensor(m_FlattenedCoeff, flattenedCoeff_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_FlattenedCoeff);
    CopyArmComputeITensorData<int32_t>(flattenedCoeff.data(), m_FlattenedCoeff);

    // Prepare the tensor to store the output of the multiplication
    armnn::TensorInfo outputMul_Info = indicesInfo;
    outputMul_Info.SetShape({ keyIndices["W"], keyIndices["ND"] });
    BuildArmComputeTensor(m_OutputMul, outputMul_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputMul);

    // Multiply
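    // With a scale of 1.0f the s32 index-by-stride products are exact; the WRAP and TO_ZERO
    // policies only matter on overflow or fractional scaling, neither of which applies to
    // in-range indices.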
    m_MulLayer.configure(&indices,
                         &m_FlattenedCoeff,
                         &m_OutputMul,
                         1.0f,
                         arm_compute::ConvertPolicy::WRAP,
                         arm_compute::RoundingPolicy::TO_ZERO,
                         arm_compute::ActivationLayerInfo());

    // Reduce Sum
    const std::vector<unsigned int> armnnReduceAxes(1, 1);
    arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(m_OutputMul.info()->num_dimensions(),
                                                                          outputMul_Info.GetNumDimensions(),
                                                                          armnnReduceAxes);
    m_ReduceSumLayer.configure(&m_OutputMul,
                               &m_FlattenedIndices,
                               static_cast<unsigned int>(coords[0]),
                               arm_compute::ReductionOperation::SUM,
                               false);

    /// Call Gather with adequate shapes
    // Reshape params into { K, C }
    paramsInfo.SetShape({ keyIndices["K"], keyIndices["C"] });
    input.info()->set_tensor_shape(BuildArmComputeTensorShape(paramsInfo.GetShape()));

    // Reshape output to have the shape given by gather { W, C }
    // (the original outputInfo has the shape given by gatherNd)
    armnn::TensorInfo outputGather_Info = outputInfo;
    outputGather_Info.SetShape({ keyIndices["W"], keyIndices["C"] });
    BuildArmComputeTensor(m_OutputGather, outputGather_Info);
    armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_OutputGather);

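    // Gather along armnn axis 0 (the K dimension of the { K, C } view); ComputeAclAxis(0, paramsInfo)
    // translates that axis into ACL's reversed dimension numbering.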
    m_GatherLayer.configure(&input, &m_FlattenedIndices, &m_OutputGather, ComputeAclAxis(0, paramsInfo));

    // Reshape output to the original output shape
    m_ReshapeLayer.configure(&m_OutputGather, &output);
}

void NeonGatherNdWorkload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_NEON_NAME_GUID("NeonGatherNdWorkload_Execute");
    m_MulLayer.run();
    m_ReduceSumLayer.run();
    m_GatherLayer.run();
    m_ReshapeLayer.run();
}
} //namespace armnn