blob: 61f4e0e644144e0c0ee0661faf2defa6c7761730 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "Permute.hpp"
7
arovir01616e7752018-10-01 17:08:59 +01008#include "Half.hpp"
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Tensor.hpp>
10
11#include <cassert>
12
13namespace
14{
15
16class PermuteLoop
17{
18public:
19 using size_type = unsigned int;
20
21 PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
22 : m_DstShape(dstShape)
23 {
24 assert(dstShape.GetNumDimensions() == mappings.GetSize());
25
26 const size_type numDims = dstShape.GetNumDimensions();
27
28 size_type srcStride = 1U;
29 size_type dstStride = 1U;
30
31 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
32 {
33 m_SrcStrides[mappings[i]] = srcStride;
34 m_DstStrides[i] = dstStride;
35
36 srcStride *= dstShape[mappings[i]];
37 dstStride *= dstShape[i];
38 }
39 }
40
41 template <typename T>
42 void Unroll(const T* srcData, T* dstData)
43 {
44 const T* const srcEnd = srcData + m_DstShape.GetNumElements();
45 T* const dstEnd = dstData + m_DstShape.GetNumElements();
46 Unroll(0, srcData, dstData, srcEnd, dstEnd);
47 }
48
49private:
50 template <typename T>
51 void Unroll(size_type dimension, const T* srcData, T* dstData, const T* srcEnd, T* dstEnd)
52 {
53 assert(srcData < srcEnd);
54 assert(dstData < dstEnd);
55
56 if (dimension >= m_DstShape.GetNumDimensions())
57 {
58 *dstData = *srcData;
59 }
60 else
61 {
62 for (size_type i = 0; i < m_DstShape[dimension]; i++)
63 {
64 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd);
65
66 srcData += m_SrcStrides[dimension];
67 dstData += m_DstStrides[dimension];
68 }
69 }
70 }
71
72 armnn::TensorShape m_DstShape;
73 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
74 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
75};
76
77} // namespace
78
79namespace armnnUtils
80{
81
82armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
83{
84 assert(srcShape.GetNumDimensions() == mappings.GetSize());
85
86 const unsigned int numDims = mappings.GetSize();
87 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
88
89 for (unsigned int i = 0U; i < numDims; ++i)
90 {
91 outDims[mappings[i]] = srcShape[i];
92 }
93
94 armnn::TensorShape permutedShape(numDims, outDims);
95 return permutedShape;
96}
97
98armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
99{
100 armnn::TensorInfo outInfo(info);
101 outInfo.SetShape(Permuted(info.GetShape(), mappings));
102 return outInfo;
103}
104
105template <typename T>
106void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst)
107{
108 PermuteLoop(dstShape, mappings).Unroll(src, dst);
109}
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111// Instantiates for types.
telsoa014fcda012018-03-09 14:13:49 +0000112template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
arovir01616e7752018-10-01 17:08:59 +0100113 const armnn::Half* src, armnn::Half* dst);
114template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
telsoa014fcda012018-03-09 14:13:49 +0000115 const float* src, float* dst);
116template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
117 const uint8_t* src, uint8_t* dst);
118template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
119 const int32_t* src, int32_t* dst);
120
121} // namespace armnnUtils