blob: 377046367c93a781c0b06ff177c21d63876656b9 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Matteo Martincighe011d202019-11-28 11:35:47 +00006#include <armnn/Tensor.hpp>
7
8#include <armnnUtils/Permute.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
arovir01616e7752018-10-01 17:08:59 +010010#include "Half.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011
12#include <cassert>
Matteo Martincigh747ef822018-12-18 09:26:39 +000013#include <cstring>
telsoa014fcda012018-03-09 14:13:49 +000014
15namespace
16{
17
18class PermuteLoop
19{
20public:
21 using size_type = unsigned int;
22
23 PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
24 : m_DstShape(dstShape)
25 {
26 assert(dstShape.GetNumDimensions() == mappings.GetSize());
27
28 const size_type numDims = dstShape.GetNumDimensions();
29
30 size_type srcStride = 1U;
31 size_type dstStride = 1U;
32
33 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
34 {
35 m_SrcStrides[mappings[i]] = srcStride;
36 m_DstStrides[i] = dstStride;
37
38 srcStride *= dstShape[mappings[i]];
39 dstStride *= dstShape[i];
40 }
41 }
42
Matteo Martincigh747ef822018-12-18 09:26:39 +000043 void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
44 {
45 assert(srcData);
46 assert(dstData);
47 assert(dataTypeSize > 0);
48
49 const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
50 unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
51
52 const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
53 unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
54
55 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
56 }
57
telsoa014fcda012018-03-09 14:13:49 +000058private:
Matteo Martincigh747ef822018-12-18 09:26:39 +000059 void Unroll(size_type dimension,
60 const unsigned char* srcData, unsigned char* dstData,
61 const unsigned char* srcEnd, unsigned char* dstEnd,
62 size_t dataTypeSize)
63 {
64 assert(srcData);
65 assert(dstData);
66 assert(srcEnd);
67 assert(dstEnd);
68 assert(srcData < srcEnd);
69 assert(dstData < dstEnd);
70 assert(dataTypeSize > 0);
71
72 if (dimension >= m_DstShape.GetNumDimensions())
73 {
74 ::memcpy(dstData, srcData, dataTypeSize);
75 }
76 else
77 {
78 for (size_type i = 0; i < m_DstShape[dimension]; i++)
79 {
80 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
81
82 srcData += m_SrcStrides[dimension] * dataTypeSize;
83 dstData += m_DstStrides[dimension] * dataTypeSize;
84 }
85 }
86 }
87
telsoa014fcda012018-03-09 14:13:49 +000088 armnn::TensorShape m_DstShape;
89 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
90 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
91};
92
93} // namespace
94
95namespace armnnUtils
96{
97
Francis Murtagh0fe73762020-08-20 15:38:29 +010098armnn::TensorShape Permuted(const armnn::TensorShape& srcShape,
99 const armnn::PermutationVector& mappings)
telsoa014fcda012018-03-09 14:13:49 +0000100{
101 assert(srcShape.GetNumDimensions() == mappings.GetSize());
102
103 const unsigned int numDims = mappings.GetSize();
104 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
105
106 for (unsigned int i = 0U; i < numDims; ++i)
107 {
108 outDims[mappings[i]] = srcShape[i];
109 }
110
111 armnn::TensorShape permutedShape(numDims, outDims);
112 return permutedShape;
113}
114
Francis Murtagh0fe73762020-08-20 15:38:29 +0100115armnn::TensorInfo Permuted(const armnn::TensorInfo& info,
116 const armnn::PermutationVector& mappings,
117 bool perChannelPermute)
telsoa014fcda012018-03-09 14:13:49 +0000118{
119 armnn::TensorInfo outInfo(info);
120 outInfo.SetShape(Permuted(info.GetShape(), mappings));
Francis Murtagh0fe73762020-08-20 15:38:29 +0100121
122 // If TensorInfo has Per-Axis Quantization then permute QuantizationDim to mapping
123 if (info.HasPerAxisQuantization() && perChannelPermute)
124 {
125 outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]);
126 }
127
telsoa014fcda012018-03-09 14:13:49 +0000128 return outInfo;
129}
130
Matteo Martincigh747ef822018-12-18 09:26:39 +0000131void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
132 const void* src, void* dst, size_t dataTypeSize)
133{
134 PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
135}
136
telsoa014fcda012018-03-09 14:13:49 +0000137} // namespace armnnUtils