blob: 45491c7f525a0c555050b506da575a272deec90a [file] [log] [blame]
narpra014951d842019-01-18 16:53:53 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "Gather.hpp"
7
8#include "RefWorkloadUtils.hpp"
9
10#include <backendsCommon/WorkloadData.hpp>
11
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010012#include <boost/numeric/conversion/cast.hpp>
13
narpra014951d842019-01-18 16:53:53 +000014namespace armnn
15{
16
17template <typename T>
18void Gather(const TensorInfo& paramsInfo,
19 const TensorInfo& indicesInfo,
20 const TensorInfo& outputInfo,
21 const T* params,
22 const int32_t* indices,
23 T* output)
24{
25 const TensorShape& paramsShape = paramsInfo.GetShape();
26
27 unsigned int paramsProduct = 1;
28 for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
29 {
30 paramsProduct = paramsProduct * paramsShape[i];
31 }
32
33 unsigned int outIndex = 0;
34 for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
35 {
36 unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
37
38 BOOST_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
39
40 unsigned int startOffset = indx * paramsProduct;
41 unsigned int endOffset = startOffset + paramsProduct;
42 for (unsigned int j = startOffset; j < endOffset; ++j)
43 {
44 output[outIndex] = params[j];
45 ++outIndex;
46 }
47 }
48
49 BOOST_ASSERT(outIndex == outputInfo.GetNumElements());
50}
51
52template void Gather<float>(const TensorInfo& paramsInfo,
53 const TensorInfo& indicesInfo,
54 const TensorInfo& outputInfo,
55 const float* params,
56 const int32_t* indices,
57 float* output);
58
59template void Gather<uint8_t>(const TensorInfo& paramsInfo,
60 const TensorInfo& indicesInfo,
61 const TensorInfo& outputInfo,
62 const uint8_t* params,
63 const int32_t* indices,
64 uint8_t* output);
65
66} //namespace armnn