blob: 68600c9a9516b5e49d23ed26ab40737b8e215653 [file] [log] [blame]
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "StridedSlice.hpp"
7
Matteo Martincighe851b3d2019-05-28 14:31:20 +01008#include <ResolveType.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000012
Matteo Martincighe851b3d2019-05-28 14:31:20 +010013#include <cstring>
14
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000015namespace armnn
16{
17
Matteo Martincighe851b3d2019-05-28 14:31:20 +010018namespace
19{
20
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000021void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
22{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010023 ARMNN_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000024
25 const unsigned int beginIndicesCount =
Matthew Sloyan171214c2020-09-09 09:07:37 +010026 armnn::numeric_cast<unsigned int>(p.m_Begin.size());
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000027
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010028 ARMNN_ASSERT(dimCount >= beginIndicesCount);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000029 const unsigned int padCount = dimCount - beginIndicesCount;
30
31 p.m_Begin.resize(dimCount);
32 p.m_End.resize(dimCount);
33 p.m_Stride.resize(dimCount);
34
35 for (unsigned int i = beginIndicesCount; i > 0; --i)
36 {
37 p.m_Stride[i + padCount - 1] = p.m_Stride[i - 1];
38 p.m_Begin[i + padCount - 1] = p.m_Begin[i - 1];
39 p.m_End[i + padCount - 1] = p.m_End[i - 1];
40 }
41
42 for (unsigned int i = 0; i < padCount; ++i)
43 {
44 p.m_Stride[i] = 1;
45 p.m_Begin[i] = 0;
46 p.m_End[i] = 0;
47 }
48
49 p.m_ShrinkAxisMask <<= padCount;
50 p.m_EllipsisMask <<= padCount;
51 p.m_NewAxisMask <<= padCount;
52 p.m_BeginMask <<= padCount;
53 p.m_EndMask <<= padCount;
54 p.m_BeginMask |= (1 << padCount) - 1;
55 p.m_EndMask |= (1 << padCount) - 1;
56}
57
/// Termination test for a strided loop.
///
/// Returns true once `index` has reached or moved past `stop` in the direction given by
/// `stride`. Callers keep iterating while it returns false. Any stride that is not strictly
/// positive is treated as a descending loop.
bool LoopCondition(int index, int stop, int stride)
{
    const bool ascending = (stride > 0);
    if (ascending)
    {
        return !(index < stop);
    }
    return !(index > stop);
}
62
63TensorShape ExtendShape(const TensorShape& inputShape,
64 unsigned int newNumDimensions)
65{
66 if (inputShape.GetNumDimensions() >= newNumDimensions)
67 {
68 return inputShape;
69 }
70
Rob Hughes9e10c2b2019-07-23 15:37:19 +010071 std::vector<unsigned int> newSizes(newNumDimensions, 0);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000072
73 unsigned int diff = newNumDimensions - inputShape.GetNumDimensions();
74
75 for (unsigned int i = 0; i < diff; i++)
76 {
77 newSizes[i] = 1;
78 }
79
80 for (unsigned int i = diff; i < newNumDimensions; i++)
81 {
82 newSizes[i] = inputShape[i - diff];
83 }
84
Rob Hughes9e10c2b2019-07-23 15:37:19 +010085 return TensorShape(newNumDimensions, newSizes.data());
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000086}
87
Matteo Martincighe851b3d2019-05-28 14:31:20 +010088} // Anonymous namespace
89
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000090void StridedSlice(const TensorInfo& inputInfo,
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000091 const StridedSliceDescriptor& params,
Matteo Martincighe851b3d2019-05-28 14:31:20 +010092 const void* inputData,
93 void* outputData,
94 unsigned int dataTypeSize)
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000095{
David Monahan6a1d5062023-08-29 09:10:50 +010096 if (inputData == nullptr)
97 {
98 throw armnn::InvalidArgumentException("Slice: Null inputData pointer");
99 }
100 if (outputData == nullptr)
101 {
102 throw armnn::InvalidArgumentException("Slice: Null outputData pointer");
103 }
104
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100105 const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
106 unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
107
108 const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000109
110 StridedSliceDescriptor paddedParams = params;
111
112 // Pad parameters to 4 dimensions
113 PadParams(paddedParams, 4);
114
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100115 const int start0 = paddedParams.GetStartForAxis(inputShape, 0);
116 const int stop0 = paddedParams.GetStopForAxis (inputShape, 0, start0);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000117
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100118 const int start1 = paddedParams.GetStartForAxis(inputShape, 1);
119 const int stop1 = paddedParams.GetStopForAxis (inputShape, 1, start1);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000120
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100121 const int start2 = paddedParams.GetStartForAxis(inputShape, 2);
122 const int stop2 = paddedParams.GetStopForAxis (inputShape, 2, start2);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000123
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100124 const int start3 = paddedParams.GetStartForAxis(inputShape, 3);
125 const int stop3 = paddedParams.GetStopForAxis (inputShape, 3, start3);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000126
Matthew Sloyan171214c2020-09-09 09:07:37 +0100127 const int step = armnn::numeric_cast<int>(dataTypeSize);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000128
129 for (int in0 = start0;
130 !LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
131 in0 += paddedParams.m_Stride[0])
132 {
133 for (int in1 = start1;
134 !LoopCondition(in1, stop1, paddedParams.m_Stride[1]);
135 in1 += paddedParams.m_Stride[1])
136 {
137 for (int in2 = start2;
138 !LoopCondition(in2, stop2, paddedParams.m_Stride[2]);
139 in2 += paddedParams.m_Stride[2])
140 {
141 for (int in3 = start3;
142 !LoopCondition(in3, stop3, paddedParams.m_Stride[3]);
143 in3 += paddedParams.m_Stride[3])
144 {
Matthew Sloyan171214c2020-09-09 09:07:37 +0100145 int dim1 = armnn::numeric_cast<int>(inputShape[1]);
146 int dim2 = armnn::numeric_cast<int>(inputShape[2]);
147 int dim3 = armnn::numeric_cast<int>(inputShape[3]);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000148
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100149 int inputOffset = (((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3) * step;
150 ::memcpy(output, input + inputOffset, dataTypeSize);
151 output += step;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +0000152 }
153 }
154 }
155 }
156}
157
Matteo Martincighe851b3d2019-05-28 14:31:20 +0100158} // namespace armnn