blob: 19b465ba5d17a1f22e839647f2024010184884ac [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Matteo Martincighe011d202019-11-28 11:35:47 +00006#include <armnn/Tensor.hpp>
7
8#include <armnnUtils/Permute.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
arovir01616e7752018-10-01 17:08:59 +010010#include "Half.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011
Matteo Martincigh747ef822018-12-18 09:26:39 +000012#include <cstring>
telsoa014fcda012018-03-09 14:13:49 +000013
14namespace
15{
16
17class PermuteLoop
18{
19public:
20 using size_type = unsigned int;
21
22 PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
23 : m_DstShape(dstShape)
24 {
David Monahan6a1d5062023-08-29 09:10:50 +010025 if (dstShape.GetNumDimensions() != mappings.GetSize())
26 {
27 std::stringstream msg;
28 msg << "Permute: Number of shape dimensions (" << dstShape.GetNumDimensions() <<
29 ") does not match the size of the mappings (" << mappings.GetSize() << ")";
30 throw armnn::InvalidArgumentException(msg.str());
31 }
telsoa014fcda012018-03-09 14:13:49 +000032
33 const size_type numDims = dstShape.GetNumDimensions();
34
35 size_type srcStride = 1U;
36 size_type dstStride = 1U;
37
38 for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
39 {
40 m_SrcStrides[mappings[i]] = srcStride;
41 m_DstStrides[i] = dstStride;
42
43 srcStride *= dstShape[mappings[i]];
44 dstStride *= dstShape[i];
45 }
46 }
47
Matteo Martincigh747ef822018-12-18 09:26:39 +000048 void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
49 {
David Monahan6a1d5062023-08-29 09:10:50 +010050 if (srcData == nullptr)
51 {
52 throw armnn::InvalidArgumentException("Permute: Source Data pointer is null");
53 }
54 if (dstData == nullptr)
55 {
56 throw armnn::InvalidArgumentException("Permute: Destination Data pointer is null");
57 }
58 if (dataTypeSize == 0)
59 {
60 throw armnn::InvalidArgumentException("Permute: dataTypeSize is zero");
61 }
Matteo Martincigh747ef822018-12-18 09:26:39 +000062
63 const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
64 unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
65
66 const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
67 unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
68
69 Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
70 }
71
telsoa014fcda012018-03-09 14:13:49 +000072private:
Matteo Martincigh747ef822018-12-18 09:26:39 +000073 void Unroll(size_type dimension,
74 const unsigned char* srcData, unsigned char* dstData,
75 const unsigned char* srcEnd, unsigned char* dstEnd,
76 size_t dataTypeSize)
77 {
David Monahan6a1d5062023-08-29 09:10:50 +010078 if (srcData == nullptr)
79 {
80 throw armnn::InvalidArgumentException("Permute: Source Data pointer is null");
81 }
82 if (dstData == nullptr)
83 {
84 throw armnn::InvalidArgumentException("Permute: Destination Data pointer is null");
85 }
86 if (srcEnd == nullptr)
87 {
88 throw armnn::InvalidArgumentException("Permute: Source End pointer is null");
89 }
90 if (dstEnd == nullptr)
91 {
92 throw armnn::InvalidArgumentException("Permute: Destination End pointer is null");
93 }
94 if (dataTypeSize == 0)
95 {
96 throw armnn::Exception("Permute: dataTypeSize is zero");
97 }
Matteo Martincigh747ef822018-12-18 09:26:39 +000098
99 if (dimension >= m_DstShape.GetNumDimensions())
100 {
101 ::memcpy(dstData, srcData, dataTypeSize);
102 }
103 else
104 {
105 for (size_type i = 0; i < m_DstShape[dimension]; i++)
106 {
107 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
108
109 srcData += m_SrcStrides[dimension] * dataTypeSize;
110 dstData += m_DstStrides[dimension] * dataTypeSize;
111 }
112 }
113 }
114
telsoa014fcda012018-03-09 14:13:49 +0000115 armnn::TensorShape m_DstShape;
116 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
117 std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
118};
119
120} // namespace
121
122namespace armnnUtils
123{
124
Francis Murtagh0fe73762020-08-20 15:38:29 +0100125armnn::TensorShape Permuted(const armnn::TensorShape& srcShape,
126 const armnn::PermutationVector& mappings)
telsoa014fcda012018-03-09 14:13:49 +0000127{
David Monahan6a1d5062023-08-29 09:10:50 +0100128 if (srcShape.GetNumDimensions() != mappings.GetSize())
129 {
130 std::stringstream msg;
131 msg << "Permute: Number of shape dimensions (" << srcShape.GetNumDimensions() <<
132 ") does not match the size of the mappings (" << mappings.GetSize() << ")";
133 throw armnn::InvalidArgumentException(msg.str());
134 }
telsoa014fcda012018-03-09 14:13:49 +0000135
136 const unsigned int numDims = mappings.GetSize();
137 unsigned int outDims[armnn::MaxNumOfTensorDimensions];
138
139 for (unsigned int i = 0U; i < numDims; ++i)
140 {
141 outDims[mappings[i]] = srcShape[i];
142 }
143
144 armnn::TensorShape permutedShape(numDims, outDims);
145 return permutedShape;
146}
147
Francis Murtagh0fe73762020-08-20 15:38:29 +0100148armnn::TensorInfo Permuted(const armnn::TensorInfo& info,
Jan Eilers7612bd62021-04-06 17:29:03 +0100149 const armnn::PermutationVector& mappings)
telsoa014fcda012018-03-09 14:13:49 +0000150{
151 armnn::TensorInfo outInfo(info);
152 outInfo.SetShape(Permuted(info.GetShape(), mappings));
Francis Murtagh0fe73762020-08-20 15:38:29 +0100153
Jan Eilers7612bd62021-04-06 17:29:03 +0100154 // If TensorInfo has Per-Axis Quantization then it also has a QuantizationDim which needs to
155 // be permuted according to the mapping
156 if (info.GetQuantizationDim().has_value())
Francis Murtagh0fe73762020-08-20 15:38:29 +0100157 {
158 outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]);
159 }
160
telsoa014fcda012018-03-09 14:13:49 +0000161 return outInfo;
162}
163
Matteo Martincigh747ef822018-12-18 09:26:39 +0000164void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
165 const void* src, void* dst, size_t dataTypeSize)
166{
167 PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
168}
169
telsoa014fcda012018-03-09 14:13:49 +0000170} // namespace armnnUtils