blob: 534a063ed544eda44f6f12c34922edc4899df126 [file] [log] [blame]
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "Slice.hpp"

#include <armnn/utility/Assert.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <cstddef>
#include <cstring>
#include <sstream>
#include <vector>
10
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +010011namespace armnn
12{
13
14void Slice(const TensorInfo& inputInfo,
15 const SliceDescriptor& descriptor,
16 const void* inputData,
17 void* outputData,
18 unsigned int dataTypeSize)
19{
20 const TensorShape& inputShape = inputInfo.GetShape();
21 const unsigned int numDims = inputShape.GetNumDimensions();
22
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +010023 constexpr unsigned int maxNumDims = 4;
David Monahan6a1d5062023-08-29 09:10:50 +010024 if (descriptor.m_Begin.size() != numDims)
25 {
26 std::stringstream msg;
27 msg << "Slice: Number of dimensions (" << numDims <<
28 ") does not match the Begin vector in the descriptor (" << descriptor.m_Begin.size() << ")";
29 throw InvalidArgumentException(msg.str());
30 }
31 if (descriptor.m_Size.size() != numDims)
32 {
33 std::stringstream msg;
34 msg << "Slice: Number of dimensions (" << numDims <<
35 ") does not match the Size vector in the descriptor (" << descriptor.m_Size.size() << ")";
36 throw InvalidArgumentException(msg.str());
37 }
38 if (numDims > maxNumDims)
39 {
40 std::stringstream msg;
41 msg << "Slice: Number of dimensions (" << numDims <<
42 ") is greater than the maximum supported (" << maxNumDims << ")";
43 throw InvalidArgumentException(msg.str());
44 }
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +010045
46 std::vector<unsigned int> paddedInput(4);
47 std::vector<unsigned int> paddedBegin(4);
48 std::vector<unsigned int> paddedSize (4);
49
50 const unsigned int numPaddingDims = maxNumDims - numDims;
51 for (unsigned int i = 0u; i < maxNumDims; ++i)
52 {
53 if (i < numPaddingDims)
54 {
55 paddedInput[i] = 1u;
56 paddedBegin[i] = 0u;
57 paddedSize[i] = 1u;
58 }
59 else
60 {
61 const unsigned int j = i - numPaddingDims;
62 paddedInput[i] = inputShape[j];
63 paddedBegin[i] = descriptor.m_Begin[j];
64 paddedSize[i] = descriptor.m_Size[j];
65 }
66 }
67
68 unsigned int dim0 = paddedInput[0];
69 unsigned int dim1 = paddedInput[1];
70 unsigned int dim2 = paddedInput[2];
71 unsigned int dim3 = paddedInput[3];
72
73 unsigned int begin0 = paddedBegin[0];
74 unsigned int begin1 = paddedBegin[1];
75 unsigned int begin2 = paddedBegin[2];
76 unsigned int begin3 = paddedBegin[3];
77
78 unsigned int size0 = paddedSize[0];
79 unsigned int size1 = paddedSize[1];
80 unsigned int size2 = paddedSize[2];
81 unsigned int size3 = paddedSize[3];
82
David Monahan6a1d5062023-08-29 09:10:50 +010083 if (begin0 + size0 > dim0)
84 {
85 std::stringstream msg;
86 msg << "Slice: begin0 + size0 (" << (begin0 + size0) <<
87 ") exceeds dim0 (" << dim0 << ")";
88 throw InvalidArgumentException(msg.str());
89 }
90 if (begin1 + size1 > dim1)
91 {
92 std::stringstream msg;
93 msg << "Slice: begin1 + size1 (" << (begin1 + size1) <<
94 ") exceeds dim2 (" << dim1 << ")";
95 throw InvalidArgumentException(msg.str());
96 }
97 if (begin2 + size2 > dim2)
98 {
99 std::stringstream msg;
100 msg << "Slice: begin2 + size2 (" << (begin2 + size2) <<
101 ") exceeds dim2 (" << dim2 << ")";
102 throw InvalidArgumentException(msg.str());
103 }
104 if (begin3 + size3 > dim3)
105 {
106 std::stringstream msg;
107 msg << "Slice: begin3 + size3 (" << (begin3 + size3) <<
108 ") exceeds dim3 (" << dim3 << ")";
109 throw InvalidArgumentException(msg.str());
110 }
111
112 if (inputData == nullptr)
113 {
114 throw armnn::NullPointerException("Slice: Null inputData pointer");
115 }
116 if (outputData == nullptr)
117 {
118 throw armnn::NullPointerException("Slice: Null outputData pointer");
119 }
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +0100120
121 const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
122 unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
123
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +0100124 for (unsigned int idx0 = begin0; idx0 < begin0 + size0; ++idx0)
125 {
126 for (unsigned int idx1 = begin1; idx1 < begin1 + size1; ++idx1)
127 {
128 for (unsigned int idx2 = begin2; idx2 < begin2 + size2; ++idx2)
129 {
130 for (unsigned int idx3 = begin3; idx3 < begin3 + size3; ++idx3)
131 {
132 const unsigned int inputOffset =
133 (((idx0 * dim1 + idx1) * dim2 + idx2) * dim3 + idx3) * dataTypeSize;
134
135 ::memcpy(output, input + inputOffset, dataTypeSize);
136 output += dataTypeSize;
137 }
138 }
139 }
140 }
141}
142
143} // namespace armnn