blob: 03aa2458f5d637ee2ede40b722e5c685eafe2ca3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "Gather.hpp"
7
8#include "RefWorkloadUtils.hpp"
9
10#include <backendsCommon/WorkloadData.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010013
narpra014951d842019-01-18 16:53:53 +000014namespace armnn
15{
16
narpra014951d842019-01-18 16:53:53 +000017void Gather(const TensorInfo& paramsInfo,
18 const TensorInfo& indicesInfo,
19 const TensorInfo& outputInfo,
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010020 Decoder<float>& params,
narpra014951d842019-01-18 16:53:53 +000021 const int32_t* indices,
Teresa Charlin52664732020-06-29 16:27:03 +010022 Encoder<float>& output,
23 const int32_t axis)
narpra014951d842019-01-18 16:53:53 +000024{
Jan Eilers8eb25602020-03-09 12:13:48 +000025 IgnoreUnused(outputInfo);
Teresa Charlin52664732020-06-29 16:27:03 +010026 IgnoreUnused(axis);
27
narpra014951d842019-01-18 16:53:53 +000028 const TensorShape& paramsShape = paramsInfo.GetShape();
29
30 unsigned int paramsProduct = 1;
31 for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
32 {
33 paramsProduct = paramsProduct * paramsShape[i];
34 }
35
36 unsigned int outIndex = 0;
37 for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
38 {
Matthew Sloyan171214c2020-09-09 09:07:37 +010039 unsigned int indx = armnn::numeric_cast<unsigned int>(indices[i]);
narpra014951d842019-01-18 16:53:53 +000040
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010041 ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
narpra014951d842019-01-18 16:53:53 +000042
43 unsigned int startOffset = indx * paramsProduct;
44 unsigned int endOffset = startOffset + paramsProduct;
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010045
narpra014951d842019-01-18 16:53:53 +000046 for (unsigned int j = startOffset; j < endOffset; ++j)
47 {
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010048 params[j];
49 float outputValue = params.Get();
50 output[outIndex];
51 output.Set(outputValue);
narpra014951d842019-01-18 16:53:53 +000052 ++outIndex;
53 }
54 }
55
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010056 ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
narpra014951d842019-01-18 16:53:53 +000057}
58
narpra014951d842019-01-18 16:53:53 +000059} //namespace armnn