blob: 896f9050f57a2032c40e6aa2df42d0d7becbeb21 [file] [log] [blame]
Tianle Cheng988354d2023-06-28 13:20:47 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ReverseV2Impl.hpp"
7
8#include <armnn/backends/WorkloadData.hpp>
9#include <armnn/Logging.hpp>
10#include <armnnUtils/Permute.hpp>
11
12namespace armnn
13{
14
// Get multi-dimensional index for input tensor
//
// Unravels the flat offset `idx` into one coordinate per dimension.
//   idx             - flat offset of the element
//   inputRank       - number of dimensions to produce coordinates for
//   elementNumInner - per-dimension divisor; entry d must be non-zero and is the
//                     number of elements spanned by one step along dimension d
// Returns a vector of `inputRank` coordinates, dimension 0 first.
std::vector<unsigned int> ReverseGetMultIdx(const unsigned int idx,
                                            unsigned int inputRank,
                                            std::vector<unsigned int>& elementNumInner)
{
    std::vector<unsigned int> coords;
    coords.reserve(inputRank);

    // Peel off one coordinate per dimension; what is left over after the
    // division is the offset inside the remaining (inner) dimensions.
    unsigned int remainder = idx;
    for (unsigned int dim = 0; dim != inputRank; ++dim)
    {
        const unsigned int stride = elementNumInner[dim];
        coords.push_back(static_cast<unsigned int>(remainder / stride));
        remainder = remainder % stride;
    }

    return coords;
}
32
// Get flattened index for output encoder
//
// Collapses a list of per-dimension coordinates back into a single flat offset.
//   idxList         - coordinate for each dimension, dimension 0 first
//   inputRank       - number of dimensions to combine
//   elementNumInner - per-dimension multiplier (elements spanned by one step
//                     along that dimension)
// Returns the sum over all dimensions of coordinate * multiplier.
unsigned int ReverseGetFlatIdx(const std::vector<unsigned int>& idxList,
                               unsigned int inputRank,
                               std::vector<unsigned int>& elementNumInner)
{
    unsigned int flatOffset = 0;
    unsigned int dim = 0;

    while (dim < inputRank)
    {
        const unsigned int stride = elementNumInner[dim];
        flatOffset += idxList[dim] * stride;
        ++dim;
    }

    return flatOffset;
}
47
// Relocate the coordinate to the reversed tensor
//
// Maps the flat offset of an input element to the flat offset that element
// occupies in the output once every flagged dimension has been mirrored.
//   idx             - flat offset of the element in the input
//   inputRank       - number of dimensions
//   axisFlag        - axisFlag[d] is true when dimension d is reversed
//   dimSize         - extent of each dimension
//   elementNumInner - number of elements spanned by one step along each
//                     dimension; every entry must be non-zero
// Returns the flat offset of the element in the output.
//
// The unravel / mirror / ravel steps are fused into a single pass so that no
// temporary index vectors are allocated: this function is invoked once per
// tensor element, so per-call heap allocations dominate its cost.
unsigned int ReverseRelocateIdx(unsigned int idx,
                                unsigned int inputRank,
                                std::vector<bool>& axisFlag,
                                std::vector<unsigned int>& dimSize,
                                std::vector<unsigned int>& elementNumInner)
{
    unsigned int remaining = idx;
    unsigned int outputIdx = 0;

    for (unsigned int iDim = 0; iDim < inputRank; ++iDim)
    {
        const unsigned int stride = elementNumInner[iDim];

        // Coordinate of the input element along this dimension
        const unsigned int inputCoord = remaining / stride;
        remaining %= stride;

        // A reversed dimension maps coordinate c to (size - 1 - c)
        const unsigned int outputCoord = axisFlag[iDim] ? (dimSize[iDim] - inputCoord - 1)
                                                        : inputCoord;

        outputIdx += outputCoord * stride;
    }

    return outputIdx;
}
77
Tracy Narinebb8d7592023-07-13 16:50:54 +010078void ReverseV2(const TensorInfo& inputInfo,
79 const TensorInfo& axisInfo,
Tianle Cheng988354d2023-06-28 13:20:47 +010080 Decoder<float>& inputDecoder,
Tracy Narinebb8d7592023-07-13 16:50:54 +010081 Decoder<int>& axisDecoder,
Tianle Cheng988354d2023-06-28 13:20:47 +010082 Encoder<float>& outputEncoder)
83{
Tracy Narinebb8d7592023-07-13 16:50:54 +010084 unsigned int axesRank = static_cast<unsigned int>(axisInfo.GetNumElements());
85
Tianle Cheng988354d2023-06-28 13:20:47 +010086 // Empty axis and empty tensor case: copy input to output
Tracy Narinebb8d7592023-07-13 16:50:54 +010087 if ((axesRank == 0) || inputInfo.GetNumElements() == 0)
Tianle Cheng988354d2023-06-28 13:20:47 +010088 {
89 for (unsigned idx = 0; idx < inputInfo.GetNumElements(); idx++)
90 {
91 float inputValue = inputDecoder.Get();
92 inputDecoder += 1;
93 outputEncoder.Set(inputValue);
94 outputEncoder += 1;
95 }
96 return;
97 }
98
99 unsigned int inputRank = static_cast<unsigned int>(inputInfo.GetNumDimensions());
100
Tracy Narinebb8d7592023-07-13 16:50:54 +0100101 std::vector<bool> axisFlag(inputRank, false);
102 std::vector<unsigned int> dimSize(inputRank, 0);
103 std::vector<int32_t> axis(axesRank, 0);
104
105 // Decode the axis information
106 for (unsigned int i=0; i < axesRank; i++)
107 {
108 axis[i] = axisDecoder.Get();
109 axisDecoder += 1;
110 }
Tianle Cheng988354d2023-06-28 13:20:47 +0100111
112 // Make sure the axes are positive
Tracy Narinebb8d7592023-07-13 16:50:54 +0100113 for (int32_t axisElement: axis)
Tianle Cheng988354d2023-06-28 13:20:47 +0100114 {
115 axisElement = axisElement < 0 ? axisElement + static_cast<int32_t>(inputRank) : axisElement;
116 axisFlag[static_cast<uint32_t>(axisElement)] = true;
117 }
118
119 const TensorShape &inputShape = inputInfo.GetShape();
120
121 unsigned int elementNum = inputInfo.GetNumElements();
122 unsigned int baseDimSize = 1;
123
124 std::vector<unsigned int> elementNumInner;
125
126 // Get the number of element within the specific dimension
127 for (unsigned int iDim = 0; iDim < inputRank; ++iDim) {
128 dimSize[iDim] = inputShape[iDim];
129 baseDimSize *= dimSize[iDim];
130 elementNumInner.push_back(static_cast<unsigned int>(elementNum / baseDimSize));
131 }
132
133 // Iterate through all elements
134 for (unsigned int idx = 0; idx < elementNum; ++idx)
135 {
136 float inputValue = inputDecoder.Get();
137 inputDecoder += 1;
138 auto outputIdx = ReverseRelocateIdx(idx, inputRank, axisFlag, dimSize, elementNumInner);
139 outputEncoder[outputIdx];
140 outputEncoder.Set(inputValue);
141 }
142}
143
144} // namespace armnn