blob: 3457cacd65af6aded3257007613f7f71516f9217 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <armnn/Tensor.hpp>
7
8#include <armnnUtils/Transpose.hpp>
9
10#include "Half.hpp"
11
12#include <cassert>
13#include <cstring>
14
15namespace
16{
17
18class TransposeLoop
19{
20public:
21 using size_type = unsigned int;
22
23 TransposeLoop(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
24 : m_SrcShape(srcShape)
25 {
26 assert(srcShape.GetNumDimensions() == mappings.GetSize());
27
28 const size_type numDims = srcShape.GetNumDimensions();
29
30 size_type srcStride = 1U;
31 size_type dstStride = 1U;
32
33 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
34 {
35 m_SrcStrides[i] = srcStride;
36 m_DstStrides[mappings[i]] = dstStride;
37
38 srcStride *= srcShape[i];
39 dstStride *= srcShape[mappings[i]];
40 }
41 }
42
43 void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
44 {
45 assert(srcData);
46 assert(dstData);
47 assert(dataTypeSize > 0);
48
49 const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
50 unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
51
52 const unsigned char* const srcEndPtr = srcDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
53 unsigned char* const dstEndPtr = dstDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
54
55 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
56 }
57
58private:
59 void Unroll(size_type dimension,
60 const unsigned char* srcData, unsigned char* dstData,
61 const unsigned char* srcEnd, unsigned char* dstEnd,
62 size_t dataTypeSize)
63 {
64 assert(srcData);
65 assert(dstData);
66 assert(srcEnd);
67 assert(dstEnd);
68 assert(srcData < srcEnd);
69 assert(dstData < dstEnd);
70 assert(dataTypeSize > 0);
71
72 if (dimension >= m_SrcShape.GetNumDimensions())
73 {
74 ::memcpy(dstData, srcData, dataTypeSize);
75 }
76 else
77 {
78 for (size_type i = 0; i < m_SrcShape[dimension]; i++)
79 {
80 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
81
82 srcData += m_SrcStrides[dimension] * dataTypeSize;
83 dstData += m_DstStrides[dimension] * dataTypeSize;
84 }
85 }
86 }
87
88 armnn::TensorShape m_SrcShape;
89 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
90 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
91};
92
93} // namespace
94
95namespace armnnUtils
96{
97
98armnn::TensorShape TransposeTensorShape(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
99{
100 assert(srcShape.GetNumDimensions() == mappings.GetSize());
101
102 const unsigned int numDims = mappings.GetSize();
103 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
104
105 for (unsigned int i = 0U; i < numDims; ++i)
106 {
107 outDims[i] = srcShape[mappings[i]];
108 }
109 armnn::TensorShape permutedShape(numDims, outDims);
110 return permutedShape;
111}
112
113armnn::TensorInfo TransposeTensorShape(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
114{
115 armnn::TensorInfo outInfo(info);
116 outInfo.SetShape(TransposeTensorShape(info.GetShape(), mappings));
117 return outInfo;
118}
119
120void Transpose(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings,
121 const void* src, void* dst, size_t dataTypeSize)
122{
123 TransposeLoop(srcShape, mappings).Unroll(src, dst, dataTypeSize);
124}
125
126} // namespace armnnUtils