blob: b00b049ff65814c39fef630826c9bd9f833a51f7 [file] [log] [blame]
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "StridedSlice.hpp"

#include <ResolveType.hpp>

#include <armnn/utility/Assert.hpp>

#include <boost/numeric/conversion/cast.hpp>

#include <cstddef>
#include <cstring>
15
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000016namespace armnn
17{
18
Matteo Martincighe851b3d2019-05-28 14:31:20 +010019namespace
20{
21
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000022void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
23{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010024 ARMNN_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000025
26 const unsigned int beginIndicesCount =
27 boost::numeric_cast<unsigned int>(p.m_Begin.size());
28
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010029 ARMNN_ASSERT(dimCount >= beginIndicesCount);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000030 const unsigned int padCount = dimCount - beginIndicesCount;
31
32 p.m_Begin.resize(dimCount);
33 p.m_End.resize(dimCount);
34 p.m_Stride.resize(dimCount);
35
36 for (unsigned int i = beginIndicesCount; i > 0; --i)
37 {
38 p.m_Stride[i + padCount - 1] = p.m_Stride[i - 1];
39 p.m_Begin[i + padCount - 1] = p.m_Begin[i - 1];
40 p.m_End[i + padCount - 1] = p.m_End[i - 1];
41 }
42
43 for (unsigned int i = 0; i < padCount; ++i)
44 {
45 p.m_Stride[i] = 1;
46 p.m_Begin[i] = 0;
47 p.m_End[i] = 0;
48 }
49
50 p.m_ShrinkAxisMask <<= padCount;
51 p.m_EllipsisMask <<= padCount;
52 p.m_NewAxisMask <<= padCount;
53 p.m_BeginMask <<= padCount;
54 p.m_EndMask <<= padCount;
55 p.m_BeginMask |= (1 << padCount) - 1;
56 p.m_EndMask |= (1 << padCount) - 1;
57}
58
// Termination test for a strided loop: returns true once `index` has reached or
// passed `stop` in the direction of travel. For a positive stride that means
// index >= stop. For any other stride (negative, or zero) it means index <= stop.
// Callers loop while this returns false.
bool LoopCondition(int index, int stop, int stride)
{
    if (stride > 0)
    {
        return index >= stop;
    }

    return index <= stop;
}
63
64TensorShape ExtendShape(const TensorShape& inputShape,
65 unsigned int newNumDimensions)
66{
67 if (inputShape.GetNumDimensions() >= newNumDimensions)
68 {
69 return inputShape;
70 }
71
Rob Hughes9e10c2b2019-07-23 15:37:19 +010072 std::vector<unsigned int> newSizes(newNumDimensions, 0);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000073
74 unsigned int diff = newNumDimensions - inputShape.GetNumDimensions();
75
76 for (unsigned int i = 0; i < diff; i++)
77 {
78 newSizes[i] = 1;
79 }
80
81 for (unsigned int i = diff; i < newNumDimensions; i++)
82 {
83 newSizes[i] = inputShape[i - diff];
84 }
85
Rob Hughes9e10c2b2019-07-23 15:37:19 +010086 return TensorShape(newNumDimensions, newSizes.data());
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000087}
88
Matteo Martincighe851b3d2019-05-28 14:31:20 +010089} // Anonymous namespace
90
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000091void StridedSlice(const TensorInfo& inputInfo,
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000092 const StridedSliceDescriptor& params,
Matteo Martincighe851b3d2019-05-28 14:31:20 +010093 const void* inputData,
94 void* outputData,
95 unsigned int dataTypeSize)
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000096{
Matteo Martincighe851b3d2019-05-28 14:31:20 +010097 const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
98 unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
99
100 const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000101
102 StridedSliceDescriptor paddedParams = params;
103
104 // Pad parameters to 4 dimensions
105 PadParams(paddedParams, 4);
106
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100107 const int start0 = paddedParams.GetStartForAxis(inputShape, 0);
108 const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000109
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100110 const int start1 = paddedParams.GetStartForAxis(inputShape, 1);
111 const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000112
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100113 const int start2 = paddedParams.GetStartForAxis(inputShape, 2);
114 const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000115
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100116 const int start3 = paddedParams.GetStartForAxis(inputShape, 3);
117 const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000118
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100119 const int step = boost::numeric_cast<int>(dataTypeSize);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000120
121 for (int in0 = start0;
122 !LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
123 in0 += paddedParams.m_Stride[0])
124 {
125 for (int in1 = start1;
126 !LoopCondition(in1, stop1, paddedParams.m_Stride[1]);
127 in1 += paddedParams.m_Stride[1])
128 {
129 for (int in2 = start2;
130 !LoopCondition(in2, stop2, paddedParams.m_Stride[2]);
131 in2 += paddedParams.m_Stride[2])
132 {
133 for (int in3 = start3;
134 !LoopCondition(in3, stop3, paddedParams.m_Stride[3]);
135 in3 += paddedParams.m_Stride[3])
136 {
137 int dim1 = boost::numeric_cast<int>(inputShape[1]);
138 int dim2 = boost::numeric_cast<int>(inputShape[2]);
139 int dim3 = boost::numeric_cast<int>(inputShape[3]);
140
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100141 int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step;
142 ::memcpy(output, input + inputOffset, dataTypeSize);
143 output += step;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000144 }
145 }
146 }
147 }
148}
149
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100150} // namespace armnn