blob: 58e58583fca129e1f529bfefee6975681625f0a0 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#include "Permute.hpp"
7
#include <armnn/Tensor.hpp>

#include <array>
#include <cassert>
#include <cstdint>
11
12namespace
13{
14
15class PermuteLoop
16{
17public:
18 using size_type = unsigned int;
19
20 PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
21 : m_DstShape(dstShape)
22 {
23 assert(dstShape.GetNumDimensions() == mappings.GetSize());
24
25 const size_type numDims = dstShape.GetNumDimensions();
26
27 size_type srcStride = 1U;
28 size_type dstStride = 1U;
29
30 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
31 {
32 m_SrcStrides[mappings[i]] = srcStride;
33 m_DstStrides[i] = dstStride;
34
35 srcStride *= dstShape[mappings[i]];
36 dstStride *= dstShape[i];
37 }
38 }
39
40 template <typename T>
41 void Unroll(const T* srcData, T* dstData)
42 {
43 const T* const srcEnd = srcData + m_DstShape.GetNumElements();
44 T* const dstEnd = dstData + m_DstShape.GetNumElements();
45 Unroll(0, srcData, dstData, srcEnd, dstEnd);
46 }
47
48private:
49 template <typename T>
50 void Unroll(size_type dimension, const T* srcData, T* dstData, const T* srcEnd, T* dstEnd)
51 {
52 assert(srcData < srcEnd);
53 assert(dstData < dstEnd);
54
55 if (dimension >= m_DstShape.GetNumDimensions())
56 {
57 *dstData = *srcData;
58 }
59 else
60 {
61 for (size_type i = 0; i < m_DstShape[dimension]; i++)
62 {
63 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd);
64
65 srcData += m_SrcStrides[dimension];
66 dstData += m_DstStrides[dimension];
67 }
68 }
69 }
70
71 armnn::TensorShape m_DstShape;
72 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
73 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
74};
75
76} // namespace
77
78namespace armnnUtils
79{
80
81armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
82{
83 assert(srcShape.GetNumDimensions() == mappings.GetSize());
84
85 const unsigned int numDims = mappings.GetSize();
86 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
87
88 for (unsigned int i = 0U; i < numDims; ++i)
89 {
90 outDims[mappings[i]] = srcShape[i];
91 }
92
93 armnn::TensorShape permutedShape(numDims, outDims);
94 return permutedShape;
95}
96
97armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
98{
99 armnn::TensorInfo outInfo(info);
100 outInfo.SetShape(Permuted(info.GetShape(), mappings));
101 return outInfo;
102}
103
104template <typename T>
105void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst)
106{
107 PermuteLoop(dstShape, mappings).Unroll(src, dst);
108}
109
110// Instantiate for types
111template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
112 const float* src, float* dst);
113template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
114 const uint8_t* src, uint8_t* dst);
115template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
116 const int32_t* src, int32_t* dst);
117
118} // namespace armnnUtils