blob: 6deff9016862ae7b6bab185934fb4cfe248ef90d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "Permute.hpp"

#include "Half.hpp"
#include <armnn/Tensor.hpp>

#include <array>
#include <cassert>
#include <cstddef>
#include <cstring>
telsoa014fcda012018-03-09 14:13:49 +000013
14namespace
15{
16
17class PermuteLoop
18{
19public:
20 using size_type = unsigned int;
21
22 PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
23 : m_DstShape(dstShape)
24 {
25 assert(dstShape.GetNumDimensions() == mappings.GetSize());
26
27 const size_type numDims = dstShape.GetNumDimensions();
28
29 size_type srcStride = 1U;
30 size_type dstStride = 1U;
31
32 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
33 {
34 m_SrcStrides[mappings[i]] = srcStride;
35 m_DstStrides[i] = dstStride;
36
37 srcStride *= dstShape[mappings[i]];
38 dstStride *= dstShape[i];
39 }
40 }
41
42 template <typename T>
43 void Unroll(const T* srcData, T* dstData)
44 {
45 const T* const srcEnd = srcData + m_DstShape.GetNumElements();
46 T* const dstEnd = dstData + m_DstShape.GetNumElements();
47 Unroll(0, srcData, dstData, srcEnd, dstEnd);
48 }
49
Matteo Martincigh747ef822018-12-18 09:26:39 +000050 void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
51 {
52 assert(srcData);
53 assert(dstData);
54 assert(dataTypeSize > 0);
55
56 const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
57 unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
58
59 const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
60 unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
61
62 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
63 }
64
telsoa014fcda012018-03-09 14:13:49 +000065private:
66 template <typename T>
67 void Unroll(size_type dimension, const T* srcData, T* dstData, const T* srcEnd, T* dstEnd)
68 {
Matteo Martincigh747ef822018-12-18 09:26:39 +000069 assert(srcData);
70 assert(dstData);
71 assert(srcEnd);
72 assert(dstEnd);
telsoa014fcda012018-03-09 14:13:49 +000073 assert(srcData < srcEnd);
74 assert(dstData < dstEnd);
75
76 if (dimension >= m_DstShape.GetNumDimensions())
77 {
78 *dstData = *srcData;
79 }
80 else
81 {
82 for (size_type i = 0; i < m_DstShape[dimension]; i++)
83 {
84 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd);
85
86 srcData += m_SrcStrides[dimension];
87 dstData += m_DstStrides[dimension];
88 }
89 }
90 }
91
Matteo Martincigh747ef822018-12-18 09:26:39 +000092 void Unroll(size_type dimension,
93 const unsigned char* srcData, unsigned char* dstData,
94 const unsigned char* srcEnd, unsigned char* dstEnd,
95 size_t dataTypeSize)
96 {
97 assert(srcData);
98 assert(dstData);
99 assert(srcEnd);
100 assert(dstEnd);
101 assert(srcData < srcEnd);
102 assert(dstData < dstEnd);
103 assert(dataTypeSize > 0);
104
105 if (dimension >= m_DstShape.GetNumDimensions())
106 {
107 ::memcpy(dstData, srcData, dataTypeSize);
108 }
109 else
110 {
111 for (size_type i = 0; i < m_DstShape[dimension]; i++)
112 {
113 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
114
115 srcData += m_SrcStrides[dimension] * dataTypeSize;
116 dstData += m_DstStrides[dimension] * dataTypeSize;
117 }
118 }
119 }
120
telsoa014fcda012018-03-09 14:13:49 +0000121 armnn::TensorShape m_DstShape;
122 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
123 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
124};
125
126} // namespace
127
128namespace armnnUtils
129{
130
131armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
132{
133 assert(srcShape.GetNumDimensions() == mappings.GetSize());
134
135 const unsigned int numDims = mappings.GetSize();
136 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
137
138 for (unsigned int i = 0U; i < numDims; ++i)
139 {
140 outDims[mappings[i]] = srcShape[i];
141 }
142
143 armnn::TensorShape permutedShape(numDims, outDims);
144 return permutedShape;
145}
146
147armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
148{
149 armnn::TensorInfo outInfo(info);
150 outInfo.SetShape(Permuted(info.GetShape(), mappings));
151 return outInfo;
152}
153
Matteo Martincigh747ef822018-12-18 09:26:39 +0000154void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
155 const void* src, void* dst, size_t dataTypeSize)
156{
157 PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
158}
159
telsoa014fcda012018-03-09 14:13:49 +0000160template <typename T>
161void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst)
162{
163 PermuteLoop(dstShape, mappings).Unroll(src, dst);
164}
165
telsoa01c577f2c2018-08-31 09:22:23 +0100166// Instantiates for types.
telsoa014fcda012018-03-09 14:13:49 +0000167template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
arovir01616e7752018-10-01 17:08:59 +0100168 const armnn::Half* src, armnn::Half* dst);
169template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
telsoa014fcda012018-03-09 14:13:49 +0000170 const float* src, float* dst);
171template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
172 const uint8_t* src, uint8_t* dst);
173template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
174 const int32_t* src, int32_t* dst);
Matteo Martincigh747ef822018-12-18 09:26:39 +0000175template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
176 const bool* src, bool* dst);
telsoa014fcda012018-03-09 14:13:49 +0000177
178} // namespace armnnUtils