blob: b195003e04f84f131f765020a5cf47ce66b051bd [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "Gather.hpp"

#include "RefWorkloadUtils.hpp"

#include <backendsCommon/WorkloadData.hpp>

#include <algorithm>
11
12namespace armnn
13{
14
15template <typename T>
16void Gather(const TensorInfo& paramsInfo,
17 const TensorInfo& indicesInfo,
18 const TensorInfo& outputInfo,
19 const T* params,
20 const int32_t* indices,
21 T* output)
22{
23 const TensorShape& paramsShape = paramsInfo.GetShape();
24
25 unsigned int paramsProduct = 1;
26 for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
27 {
28 paramsProduct = paramsProduct * paramsShape[i];
29 }
30
31 unsigned int outIndex = 0;
32 for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
33 {
34 unsigned int indx = boost::numeric_cast<unsigned int>(indices[i]);
35
36 BOOST_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
37
38 unsigned int startOffset = indx * paramsProduct;
39 unsigned int endOffset = startOffset + paramsProduct;
40 for (unsigned int j = startOffset; j < endOffset; ++j)
41 {
42 output[outIndex] = params[j];
43 ++outIndex;
44 }
45 }
46
47 BOOST_ASSERT(outIndex == outputInfo.GetNumElements());
48}
49
50template void Gather<float>(const TensorInfo& paramsInfo,
51 const TensorInfo& indicesInfo,
52 const TensorInfo& outputInfo,
53 const float* params,
54 const int32_t* indices,
55 float* output);
56
57template void Gather<uint8_t>(const TensorInfo& paramsInfo,
58 const TensorInfo& indicesInfo,
59 const TensorInfo& outputInfo,
60 const uint8_t* params,
61 const int32_t* indices,
62 uint8_t* output);
63
64} //namespace armnn