blob: 585fc81b001c92a7351e3482440bb79142de15f8 [file] [log] [blame]
narpra014951d842019-01-18 16:53:53 +00001//
Kevin May49f8d6a2023-06-01 16:42:05 +01002// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved.
narpra014951d842019-01-18 16:53:53 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "Gather.hpp"
7
Colm Donelan0c479742021-12-10 12:43:54 +00008#include <armnn/backends/WorkloadData.hpp>
Kevin May49f8d6a2023-06-01 16:42:05 +01009
10#include <fmt/format.h>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010011
narpra014951d842019-01-18 16:53:53 +000012namespace armnn
13{
14
narpra014951d842019-01-18 16:53:53 +000015void Gather(const TensorInfo& paramsInfo,
16 const TensorInfo& indicesInfo,
17 const TensorInfo& outputInfo,
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010018 Decoder<float>& params,
narpra014951d842019-01-18 16:53:53 +000019 const int32_t* indices,
Teresa Charlin52664732020-06-29 16:27:03 +010020 Encoder<float>& output,
Nikhil Raj369d8fc2022-11-24 13:12:36 +000021 const int32_t axis_int)
narpra014951d842019-01-18 16:53:53 +000022{
Jan Eilers8eb25602020-03-09 12:13:48 +000023 IgnoreUnused(outputInfo);
Nikhil Raj369d8fc2022-11-24 13:12:36 +000024
25 const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
Kevin May49f8d6a2023-06-01 16:42:05 +010026 if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int))
27 {
28 throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range",
29 axis_int, paramsRank, paramsRank)));
30 }
Nikhil Raj369d8fc2022-11-24 13:12:36 +000031 const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
32 : static_cast<unsigned int>(axis_int);
Teresa Charlin52664732020-06-29 16:27:03 +010033
narpra014951d842019-01-18 16:53:53 +000034 const TensorShape& paramsShape = paramsInfo.GetShape();
35
Nikhil Raj369d8fc2022-11-24 13:12:36 +000036 // Product of all dimensions to the left side of the axis
37 unsigned int paramsOuterProduct = 1;
38 for (unsigned int i = 0; i < axis; ++i)
narpra014951d842019-01-18 16:53:53 +000039 {
Nikhil Raj369d8fc2022-11-24 13:12:36 +000040 paramsOuterProduct *= paramsShape[i];
41 }
42 // Product of all dimensions to the right side of the axis
43 unsigned int paramsInnerProduct = 1;
44 for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
45 {
46 paramsInnerProduct *= paramsShape[k];
narpra014951d842019-01-18 16:53:53 +000047 }
48
Nikhil Raj369d8fc2022-11-24 13:12:36 +000049 unsigned int offset = 0;
narpra014951d842019-01-18 16:53:53 +000050 unsigned int outIndex = 0;
Nikhil Raj369d8fc2022-11-24 13:12:36 +000051 for (unsigned int i = 0; i < paramsOuterProduct; ++i)
narpra014951d842019-01-18 16:53:53 +000052 {
Nikhil Raj369d8fc2022-11-24 13:12:36 +000053 for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
narpra014951d842019-01-18 16:53:53 +000054 {
Kevin May49f8d6a2023-06-01 16:42:05 +010055 unsigned int index =
56 (indices[j] < 0) ? static_cast<unsigned int>(static_cast<int>(paramsShape[axis]) + indices[j])
57 : static_cast<unsigned int>(indices[j]);
58
59 if (index >= paramsShape[axis])
60 {
61 throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}",
62 index, paramsShape[axis] )));
63 }
Nikhil Raj369d8fc2022-11-24 13:12:36 +000064
65 unsigned int startOffset = (paramsInnerProduct * index) + offset;
66 unsigned int endOffset = startOffset + paramsInnerProduct;
67
68 for (unsigned int k = startOffset; k < endOffset; ++k)
69 {
70 params[k];
71 float outputValue = params.Get();
72 output[outIndex];
73 output.Set(outputValue);
74 ++outIndex;
75 }
narpra014951d842019-01-18 16:53:53 +000076 }
Nikhil Raj369d8fc2022-11-24 13:12:36 +000077 offset += paramsShape[axis] * paramsInnerProduct;
narpra014951d842019-01-18 16:53:53 +000078 }
79
Kevin May49f8d6a2023-06-01 16:42:05 +010080 if (outIndex != outputInfo.GetNumElements())
81 {
82 throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex)));
83 }
narpra014951d842019-01-18 16:53:53 +000084}
85
Nikhil Raj369d8fc2022-11-24 13:12:36 +000086} //namespace armnn