blob: 71903e421d20c12e63ce511c079d072c9d65f647 [file] [log] [blame]
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "StridedSlice.hpp"

#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>

#include <vector>
10
11namespace armnn
12{
13
14void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
15{
16 BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
17
18 const unsigned int beginIndicesCount =
19 boost::numeric_cast<unsigned int>(p.m_Begin.size());
20
21 BOOST_ASSERT(dimCount >= beginIndicesCount);
22 const unsigned int padCount = dimCount - beginIndicesCount;
23
24 p.m_Begin.resize(dimCount);
25 p.m_End.resize(dimCount);
26 p.m_Stride.resize(dimCount);
27
28 for (unsigned int i = beginIndicesCount; i > 0; --i)
29 {
30 p.m_Stride[i + padCount - 1] = p.m_Stride[i - 1];
31 p.m_Begin[i + padCount - 1] = p.m_Begin[i - 1];
32 p.m_End[i + padCount - 1] = p.m_End[i - 1];
33 }
34
35 for (unsigned int i = 0; i < padCount; ++i)
36 {
37 p.m_Stride[i] = 1;
38 p.m_Begin[i] = 0;
39 p.m_End[i] = 0;
40 }
41
42 p.m_ShrinkAxisMask <<= padCount;
43 p.m_EllipsisMask <<= padCount;
44 p.m_NewAxisMask <<= padCount;
45 p.m_BeginMask <<= padCount;
46 p.m_EndMask <<= padCount;
47 p.m_BeginMask |= (1 << padCount) - 1;
48 p.m_EndMask |= (1 << padCount) - 1;
49}
50
// Tells whether an index walking towards stop in steps of stride has finished.
//
// Returns true when index has reached or passed stop in the direction of travel:
// index >= stop for a positive stride, index <= stop otherwise (a stride of zero is
// handled like a negative one).
bool LoopCondition(int index, int stop, int stride)
{
    if (stride > 0)
    {
        return index >= stop;
    }

    return index <= stop;
}
55
56TensorShape ExtendShape(const TensorShape& inputShape,
57 unsigned int newNumDimensions)
58{
59 if (inputShape.GetNumDimensions() >= newNumDimensions)
60 {
61 return inputShape;
62 }
63
64 unsigned int newSizes[newNumDimensions];
65
66 unsigned int diff = newNumDimensions - inputShape.GetNumDimensions();
67
68 for (unsigned int i = 0; i < diff; i++)
69 {
70 newSizes[i] = 1;
71 }
72
73 for (unsigned int i = diff; i < newNumDimensions; i++)
74 {
75 newSizes[i] = inputShape[i - diff];
76 }
77
78 return TensorShape(newNumDimensions, newSizes);
79}
80
81template<typename T>
82void StridedSlice(const TensorInfo& inputInfo,
83 const TensorInfo& outputInfo,
84 const StridedSliceDescriptor& params,
85 const T* inputData,
86 T* outputData)
87{
88 const TensorShape inputShape =
89 ExtendShape(inputInfo.GetShape(), 4);
90
91 StridedSliceDescriptor paddedParams = params;
92
93 // Pad parameters to 4 dimensions
94 PadParams(paddedParams, 4);
95
96 const int start0 =
97 paddedParams.GetStartForAxis(inputShape, 0);
98 const int stop0 =
99 paddedParams.GetStopForAxis(inputShape, 0, start0);
100
101 const int start1 =
102 paddedParams.GetStartForAxis(inputShape, 1);
103 const int stop1 =
104 paddedParams.GetStopForAxis(inputShape, 1, start1);
105
106 const int start2 =
107 paddedParams.GetStartForAxis(inputShape, 2);
108 const int stop2 =
109 paddedParams.GetStopForAxis(inputShape, 2, start2);
110
111 const int start3 =
112 paddedParams.GetStartForAxis(inputShape, 3);
113 const int stop3 =
114 paddedParams.GetStopForAxis(inputShape, 3, start3);
115
116 T* outPtr = outputData;
117
118 for (int in0 = start0;
119 !LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
120 in0 += paddedParams.m_Stride[0])
121 {
122 for (int in1 = start1;
123 !LoopCondition(in1, stop1, paddedParams.m_Stride[1]);
124 in1 += paddedParams.m_Stride[1])
125 {
126 for (int in2 = start2;
127 !LoopCondition(in2, stop2, paddedParams.m_Stride[2]);
128 in2 += paddedParams.m_Stride[2])
129 {
130 for (int in3 = start3;
131 !LoopCondition(in3, stop3, paddedParams.m_Stride[3]);
132 in3 += paddedParams.m_Stride[3])
133 {
134 int dim1 = boost::numeric_cast<int>(inputShape[1]);
135 int dim2 = boost::numeric_cast<int>(inputShape[2]);
136 int dim3 = boost::numeric_cast<int>(inputShape[3]);
137
138 int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3;
139 *(outPtr++) = inputData[inputOffset];
140 }
141 }
142 }
143 }
144}
145
146template void StridedSlice<float>(const TensorInfo& inputInfo,
147 const TensorInfo& outputInfo,
148 const StridedSliceDescriptor& params,
149 const float* inputData,
150 float* outData);
151
152template void StridedSlice<uint8_t>(const TensorInfo& inputInfo,
153 const TensorInfo& outputInfo,
154 const StridedSliceDescriptor& params,
155 const uint8_t* inputData,
156 uint8_t* outData);
157
158} //namespace armnn