//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "Permute.hpp"

#include "Half.hpp"
#include <armnn/Tensor.hpp>

#include <array>
#include <cassert>
#include <cstddef>
#include <cstring>

14namespace
15{
16
17class PermuteLoop
18{
19public:
20 using size_type = unsigned int;
21
22 PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
23 : m_DstShape(dstShape)
24 {
25 assert(dstShape.GetNumDimensions() == mappings.GetSize());
26
27 const size_type numDims = dstShape.GetNumDimensions();
28
29 size_type srcStride = 1U;
30 size_type dstStride = 1U;
31
32 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
33 {
34 m_SrcStrides[mappings[i]] = srcStride;
35 m_DstStrides[i] = dstStride;
36
37 srcStride *= dstShape[mappings[i]];
38 dstStride *= dstShape[i];
39 }
40 }
41
Matteo Martincigh747ef822018-12-18 09:26:39 +000042 void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
43 {
44 assert(srcData);
45 assert(dstData);
46 assert(dataTypeSize > 0);
47
48 const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
49 unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
50
51 const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
52 unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
53
54 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
55 }
56
telsoa014fcda012018-03-09 14:13:49 +000057private:
Matteo Martincigh747ef822018-12-18 09:26:39 +000058 void Unroll(size_type dimension,
59 const unsigned char* srcData, unsigned char* dstData,
60 const unsigned char* srcEnd, unsigned char* dstEnd,
61 size_t dataTypeSize)
62 {
63 assert(srcData);
64 assert(dstData);
65 assert(srcEnd);
66 assert(dstEnd);
67 assert(srcData < srcEnd);
68 assert(dstData < dstEnd);
69 assert(dataTypeSize > 0);
70
71 if (dimension >= m_DstShape.GetNumDimensions())
72 {
73 ::memcpy(dstData, srcData, dataTypeSize);
74 }
75 else
76 {
77 for (size_type i = 0; i < m_DstShape[dimension]; i++)
78 {
79 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
80
81 srcData += m_SrcStrides[dimension] * dataTypeSize;
82 dstData += m_DstStrides[dimension] * dataTypeSize;
83 }
84 }
85 }
86
telsoa014fcda012018-03-09 14:13:49 +000087 armnn::TensorShape m_DstShape;
88 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
89 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
90};
91
92} // namespace
93
94namespace armnnUtils
95{
96
97armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
98{
99 assert(srcShape.GetNumDimensions() == mappings.GetSize());
100
101 const unsigned int numDims = mappings.GetSize();
102 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
103
104 for (unsigned int i = 0U; i < numDims; ++i)
105 {
106 outDims[mappings[i]] = srcShape[i];
107 }
108
109 armnn::TensorShape permutedShape(numDims, outDims);
110 return permutedShape;
111}
112
113armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
114{
115 armnn::TensorInfo outInfo(info);
116 outInfo.SetShape(Permuted(info.GetShape(), mappings));
117 return outInfo;
118}
119
Matteo Martincigh747ef822018-12-18 09:26:39 +0000120void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
121 const void* src, void* dst, size_t dataTypeSize)
122{
123 PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
124}
125
telsoa014fcda012018-03-09 14:13:49 +0000126} // namespace armnnUtils