blob: 48039432a5ae93c2bba8cf31d03e981e08d95779 [file] [log] [blame]
narpra014951d842019-01-18 16:53:53 +00001//
Nikhil Raj369d8fc2022-11-24 13:12:36 +00002// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
narpra014951d842019-01-18 16:53:53 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "Gather.hpp"
7
Colm Donelan0c479742021-12-10 12:43:54 +00008#include <armnn/backends/WorkloadData.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +01009#include <armnn/utility/NumericCast.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010010
narpra014951d842019-01-18 16:53:53 +000011namespace armnn
12{
13
narpra014951d842019-01-18 16:53:53 +000014void Gather(const TensorInfo& paramsInfo,
15 const TensorInfo& indicesInfo,
16 const TensorInfo& outputInfo,
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010017 Decoder<float>& params,
narpra014951d842019-01-18 16:53:53 +000018 const int32_t* indices,
Teresa Charlin52664732020-06-29 16:27:03 +010019 Encoder<float>& output,
Nikhil Raj369d8fc2022-11-24 13:12:36 +000020 const int32_t axis_int)
narpra014951d842019-01-18 16:53:53 +000021{
Jan Eilers8eb25602020-03-09 12:13:48 +000022 IgnoreUnused(outputInfo);
Nikhil Raj369d8fc2022-11-24 13:12:36 +000023
24 const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
25 ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
26 const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
27 : static_cast<unsigned int>(axis_int);
Teresa Charlin52664732020-06-29 16:27:03 +010028
narpra014951d842019-01-18 16:53:53 +000029 const TensorShape& paramsShape = paramsInfo.GetShape();
30
Nikhil Raj369d8fc2022-11-24 13:12:36 +000031 // Product of all dimensions to the left side of the axis
32 unsigned int paramsOuterProduct = 1;
33 for (unsigned int i = 0; i < axis; ++i)
narpra014951d842019-01-18 16:53:53 +000034 {
Nikhil Raj369d8fc2022-11-24 13:12:36 +000035 paramsOuterProduct *= paramsShape[i];
36 }
37 // Product of all dimensions to the right side of the axis
38 unsigned int paramsInnerProduct = 1;
39 for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
40 {
41 paramsInnerProduct *= paramsShape[k];
narpra014951d842019-01-18 16:53:53 +000042 }
43
Nikhil Raj369d8fc2022-11-24 13:12:36 +000044 unsigned int offset = 0;
narpra014951d842019-01-18 16:53:53 +000045 unsigned int outIndex = 0;
Nikhil Raj369d8fc2022-11-24 13:12:36 +000046 for (unsigned int i = 0; i < paramsOuterProduct; ++i)
narpra014951d842019-01-18 16:53:53 +000047 {
Nikhil Raj369d8fc2022-11-24 13:12:36 +000048 for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
narpra014951d842019-01-18 16:53:53 +000049 {
Nikhil Raj369d8fc2022-11-24 13:12:36 +000050 unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
51 ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
52
53 unsigned int startOffset = (paramsInnerProduct * index) + offset;
54 unsigned int endOffset = startOffset + paramsInnerProduct;
55
56 for (unsigned int k = startOffset; k < endOffset; ++k)
57 {
58 params[k];
59 float outputValue = params.Get();
60 output[outIndex];
61 output.Set(outputValue);
62 ++outIndex;
63 }
narpra014951d842019-01-18 16:53:53 +000064 }
Nikhil Raj369d8fc2022-11-24 13:12:36 +000065 offset += paramsShape[axis] * paramsInnerProduct;
narpra014951d842019-01-18 16:53:53 +000066 }
67
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010068 ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
narpra014951d842019-01-18 16:53:53 +000069}
70
Nikhil Raj369d8fc2022-11-24 13:12:36 +000071} //namespace armnn