blob: fa820aeb822b1958f87c18a7c0bade72af4fe3a5 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Ciara Sookarryabd3c212023-10-11 17:04:04 +01006#include <fmt/format.h>
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01007#include "RefGatherNdWorkload.hpp"
8
9#include "Gather.hpp"
10#include "Profiling.hpp"
11#include "RefWorkloadUtils.hpp"
12#include "backendsCommon/WorkloadUtils.hpp"
13
14namespace armnn
15{
16
17void RefGatherNdWorkload::Execute() const
18{
19 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
20}
21
Matthew Sloyan2d213a72022-06-30 17:13:04 +010022void RefGatherNdWorkload::ExecuteAsync(ExecutionData& executionData)
Teresa Charlinb2d3ec52022-04-12 22:07:09 +010023{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010024 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
25 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +010026}
27
28void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
29{
Mike Kelly7cbe7812023-07-25 17:37:33 +010030 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefGatherNdWorkload_Execute");
Teresa Charlinb2d3ec52022-04-12 22:07:09 +010031
32 const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
33 const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
34 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
35
36 std::unique_ptr<Decoder<float>> params_decoderPtr = MakeDecoder<float>(inputInfo0, inputs[0]->Map());
37
38 const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map());
39 std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements());
Ciara Sookarryabd3c212023-10-11 17:04:04 +010040 // Check for negative indices, it could not be checked in validate as we do not have access to the values there
41 for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
42 {
43 if (indices[i] < 0)
44 {
45 throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i)));
46 }
47 }
Teresa Charlinb2d3ec52022-04-12 22:07:09 +010048
49 std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
50
51 std::map<std::string, unsigned int> keyIndices = CalculateGatherNdKeyIndices(inputInfo0, inputInfo1);
52
53 /// Calculate flattened indices: flattenedIndices = indices * flattenedCoefficients
54 // Calculate the flattened coefficients to use in the multiplication
55 // to calculate the flattened indices needed by gather
56 TensorShape paramsShape = inputInfo0.GetShape();
57 std::vector<unsigned int> flattenedCoeff(keyIndices["ND"], 1);
58 for (unsigned int i = 1; i < keyIndices["ND"]; ++i)
59 {
60 flattenedCoeff[i-1] = paramsShape[i];
61 }
62 for (unsigned int i = keyIndices["ND"]-1; i > 0; --i)
63 {
64 flattenedCoeff[i-1] *= flattenedCoeff[i];
65 }
66
67 // Prepare the vector to store the output of the matrix multiplication,
68 // which will represent the flattened indices needed by gather
69 armnn::TensorInfo flattenedIndices_Info = inputInfo1;
70 flattenedIndices_Info.SetShape({ keyIndices["W"] });
71 std::vector<int32_t> flattenedIndices(flattenedIndices_Info.GetNumElements(), 0);
72
73 // Multiplication to calculate the flattened indices, which are the indices needed by gather.
74 for (unsigned int i = 0; i < keyIndices["W"]; ++i)
75 {
76 for (unsigned int j = 0; j < keyIndices["ND"]; ++j)
77 {
78 flattenedIndices[i] += indices[i * keyIndices["ND"] + j] * static_cast<int32_t>(flattenedCoeff[j]);
79 }
80 }
81
82 /// Call Gather with adequate shapes
83 // Reshape params into {K, C}
84 armnn::TensorInfo params_K_C_Info = inputInfo0;
85 params_K_C_Info.SetShape({ keyIndices["K"], keyIndices["C"] });
86
87 // Reshape indices into {N, W}
88 armnn::TensorInfo indices_N_W_Info = inputInfo1;
89 indices_N_W_Info.SetShape({ keyIndices["N"], keyIndices["W"] });
90
91 // Reshape output to have the shape given by gather {N, W, C}
92 // (the original outputInfo has the shape given by gatherNd)
93 armnn::TensorInfo outputGather_Info = outputInfo;
94 outputGather_Info.SetShape({ keyIndices["N"], keyIndices["W"], keyIndices["C"] });
95
96 // output_gather = gather(params_K_C, indices_N_W)
97 Gather(params_K_C_Info, indices_N_W_Info, outputGather_Info,
98 *params_decoderPtr, flattenedIndices.data(), *output_encoderPtr, 0);
99}
100
101} //namespace armnn