blob: 4cf3a142a079f8dfd61bc5308934049fee7fa20b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "Gather.hpp"

#include "RefWorkloadUtils.hpp"

#include <backendsCommon/WorkloadData.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <boost/numeric/conversion/cast.hpp>

#include <stdexcept>
#include <string>
14
narpra014951d842019-01-18 16:53:53 +000015namespace armnn
16{
17
narpra014951d842019-01-18 16:53:53 +000018void Gather(const TensorInfo& paramsInfo,
19 const TensorInfo& indicesInfo,
20 const TensorInfo& outputInfo,
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010021 Decoder<float>& params,
narpra014951d842019-01-18 16:53:53 +000022 const int32_t* indices,
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010023 Encoder<float>& output)
narpra014951d842019-01-18 16:53:53 +000024{
Jan Eilers8eb25602020-03-09 12:13:48 +000025 IgnoreUnused(outputInfo);
narpra014951d842019-01-18 16:53:53 +000026 const TensorShape& paramsShape = paramsInfo.GetShape();
27
28 unsigned int paramsProduct = 1;
29 for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
30 {
31 paramsProduct = paramsProduct * paramsShape[i];
32 }
33
34 unsigned int outIndex = 0;
35 for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
36 {
37 unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
38
39 BOOST_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
40
41 unsigned int startOffset = indx * paramsProduct;
42 unsigned int endOffset = startOffset + paramsProduct;
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010043
narpra014951d842019-01-18 16:53:53 +000044 for (unsigned int j = startOffset; j < endOffset; ++j)
45 {
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010046 params[j];
47 float outputValue = params.Get();
48 output[outIndex];
49 output.Set(outputValue);
narpra014951d842019-01-18 16:53:53 +000050 ++outIndex;
51 }
52 }
53
54 BOOST_ASSERT(outIndex == outputInfo.GetNumElements());
55}
56
narpra014951d842019-01-18 16:53:53 +000057} //namespace armnn