blob: 3e2190c81b1113d20f46bcd8f21ebb432eda9e6b [file] [log] [blame]
narpra014951d842019-01-18 16:53:53 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
narpra014951d842019-01-18 16:53:53 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "Gather.hpp"
7
8#include "RefWorkloadUtils.hpp"
9
10#include <backendsCommon/WorkloadData.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
narpra014951d842019-01-18 16:53:53 +000012
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010013#include <boost/numeric/conversion/cast.hpp>
14
narpra014951d842019-01-18 16:53:53 +000015namespace armnn
16{
17
narpra014951d842019-01-18 16:53:53 +000018void Gather(const TensorInfo& paramsInfo,
19 const TensorInfo& indicesInfo,
20 const TensorInfo& outputInfo,
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010021 Decoder<float>& params,
narpra014951d842019-01-18 16:53:53 +000022 const int32_t* indices,
Teresa Charlin52664732020-06-29 16:27:03 +010023 Encoder<float>& output,
24 const int32_t axis)
narpra014951d842019-01-18 16:53:53 +000025{
Jan Eilers8eb25602020-03-09 12:13:48 +000026 IgnoreUnused(outputInfo);
Teresa Charlin52664732020-06-29 16:27:03 +010027 IgnoreUnused(axis);
28
narpra014951d842019-01-18 16:53:53 +000029 const TensorShape& paramsShape = paramsInfo.GetShape();
30
31 unsigned int paramsProduct = 1;
32 for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
33 {
34 paramsProduct = paramsProduct * paramsShape[i];
35 }
36
37 unsigned int outIndex = 0;
38 for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
39 {
40 unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
41
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010042 ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
narpra014951d842019-01-18 16:53:53 +000043
44 unsigned int startOffset = indx * paramsProduct;
45 unsigned int endOffset = startOffset + paramsProduct;
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010046
narpra014951d842019-01-18 16:53:53 +000047 for (unsigned int j = startOffset; j < endOffset; ++j)
48 {
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010049 params[j];
50 float outputValue = params.Get();
51 output[outIndex];
52 output.Set(outputValue);
narpra014951d842019-01-18 16:53:53 +000053 ++outIndex;
54 }
55 }
56
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010057 ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
narpra014951d842019-01-18 16:53:53 +000058}
59
narpra014951d842019-01-18 16:53:53 +000060} //namespace armnn