blob: a8e4a1cb81a0458822ed66bdb6ab8c433ac57d5b [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include <armnn/Tensor.hpp>

#include <armnnUtils/Transpose.hpp>

#include "Half.hpp"

#include <array>
#include <cstddef>
#include <cstring>
#include <sstream>
14namespace
15{
16
17class TransposeLoop
18{
19public:
20 using size_type = unsigned int;
21
22 TransposeLoop(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
23 : m_SrcShape(srcShape)
24 {
David Monahan6a1d5062023-08-29 09:10:50 +010025 if (srcShape.GetNumDimensions() != mappings.GetSize())
26 {
27 std::stringstream msg;
28 msg << "Transpose: Number of shape dimensions (" << srcShape.GetNumDimensions() <<
29 ") does not match the size of the mappings (" << mappings.GetSize() << ")";
30 throw armnn::InvalidArgumentException(msg.str());
31 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +000032
33 const size_type numDims = srcShape.GetNumDimensions();
34
35 size_type srcStride = 1U;
36 size_type dstStride = 1U;
37
38 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
39 {
40 m_SrcStrides[i] = srcStride;
41 m_DstStrides[mappings[i]] = dstStride;
42
43 srcStride *= srcShape[i];
44 dstStride *= srcShape[mappings[i]];
45 }
46 }
47
48 void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
49 {
David Monahan6a1d5062023-08-29 09:10:50 +010050 if (srcData == nullptr)
51 {
52 throw armnn::Exception("Transpose: Source Data pointer is null");
53 }
54 if (dstData == nullptr)
55 {
56 throw armnn::Exception("Transpose: Destination Data pointer is null");
57 }
58 if (dataTypeSize == 0)
59 {
60 throw armnn::Exception("Transpose: dataTypeSize is zero");
61 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +000062
63 const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
64 unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
65
66 const unsigned char* const srcEndPtr = srcDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
67 unsigned char* const dstEndPtr = dstDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
68
69 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
70 }
71
72private:
73 void Unroll(size_type dimension,
74 const unsigned char* srcData, unsigned char* dstData,
75 const unsigned char* srcEnd, unsigned char* dstEnd,
76 size_t dataTypeSize)
77 {
David Monahan6a1d5062023-08-29 09:10:50 +010078 if (srcData == nullptr)
79 {
80 throw armnn::Exception("Transpose: Source Data pointer is null");
81 }
82 if (dstData == nullptr)
83 {
84 throw armnn::Exception("Transpose: Destination Data pointer is null");
85 }
86 if (srcEnd == nullptr)
87 {
88 throw armnn::Exception("Transpose: Source End pointer is null");
89 }
90 if (dstEnd == nullptr)
91 {
92 throw armnn::Exception("Transpose: Destination End is zero");
93 }
94 if (dataTypeSize == 0)
95 {
96 throw armnn::Exception("Transpose: dataTypeSize is invalid");
97 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +000098
99 if (dimension >= m_SrcShape.GetNumDimensions())
100 {
101 ::memcpy(dstData, srcData, dataTypeSize);
102 }
103 else
104 {
105 for (size_type i = 0; i < m_SrcShape[dimension]; i++)
106 {
107 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
108
109 srcData += m_SrcStrides[dimension] * dataTypeSize;
110 dstData += m_DstStrides[dimension] * dataTypeSize;
111 }
112 }
113 }
114
115 armnn::TensorShape m_SrcShape;
116 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
117 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
118};
119
120} // namespace
121
122namespace armnnUtils
123{
124
125armnn::TensorShape TransposeTensorShape(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
126{
David Monahan6a1d5062023-08-29 09:10:50 +0100127 if (srcShape.GetNumDimensions() != mappings.GetSize())
128 {
129 std::stringstream msg;
130 msg << "Transpose: Number of shape dimensions (" << srcShape.GetNumDimensions() <<
131 ") does not match the size of the mappings (" << mappings.GetSize() << ")";
132 throw armnn::InvalidArgumentException(msg.str());
133 }
Mike Kellyc9ea45a2020-02-28 18:11:58 +0000134
135 const unsigned int numDims = mappings.GetSize();
136 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
137
138 for (unsigned int i = 0U; i < numDims; ++i)
139 {
140 outDims[i] = srcShape[mappings[i]];
141 }
142 armnn::TensorShape permutedShape(numDims, outDims);
143 return permutedShape;
144}
145
146armnn::TensorInfo TransposeTensorShape(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
147{
148 armnn::TensorInfo outInfo(info);
149 outInfo.SetShape(TransposeTensorShape(info.GetShape(), mappings));
150 return outInfo;
151}
152
153void Transpose(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings,
154 const void* src, void* dst, size_t dataTypeSize)
155{
156 TransposeLoop(srcShape, mappings).Unroll(src, dst, dataTypeSize);
157}
158
159} // namespace armnnUtils