blob: f5e9ec549800d37fd18c9275ab0840cb392960d1 [file] [log] [blame]
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "DepthToSpace.hpp"
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
9#include <armnnUtils/Permute.hpp>
Aron Virginas-Tar73f66422019-09-23 19:11:59 +010010
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010011#include <armnn/utility/Assert.hpp>
Aron Virginas-Tar73f66422019-09-23 19:11:59 +010012
13using namespace armnnUtils;
14
15namespace armnn
16{
17
18void DepthToSpace(const TensorInfo& inputInfo,
19 const DepthToSpaceDescriptor& descriptor,
20 const void* inputData,
21 void* outputData,
22 unsigned int dataTypeSize)
23{
24 const unsigned int blockSize = descriptor.m_BlockSize;
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010025 ARMNN_ASSERT(blockSize != 0u);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +010026
27 const TensorShape& inputShape = inputInfo.GetShape();
28 const unsigned int batches = inputShape[0];
29
30 armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
31 const unsigned int inDepth = inputShape[dataLayoutIndexed.GetChannelsIndex()];
32 const unsigned int inHeight = inputShape[dataLayoutIndexed.GetHeightIndex()];
33 const unsigned int inWidth = inputShape[dataLayoutIndexed.GetWidthIndex()];
34
35 const unsigned int outDepth = inDepth / (blockSize * blockSize);
36
37 // The 4D input data can be interpreted as 6D (implicitly reshaped) as follows:
38 //
39 // [batch, block size, block size, inDepth, inHeight, inWidth] for NCHW and
40 // [batch, inHeight, inWidth, blockSize, blockSize, outDepth] for NHWC.
41 //
42 // DepthToSpace can then be implemented as a permutation in 6D resulting in
43 // the following shapes:
44 //
45 // [batch, outDepth, inHeight, blockSize, inWidth, blockSize] for NCHW and
46 // [batch, inHeight, blockSize, inWidth, blockSize, outDepth] for NHWC.
47 //
48 // NOTE:
49 // Since 6D tensors are not currently supported, in practice we need to handle each
50 // batch separately and execute 5D permutations
51
52 TensorShape permDestShape;
Aron Virginas-Tar9926e582019-09-25 12:44:53 +010053 PermutationVector permVector{};
Aron Virginas-Tar73f66422019-09-23 19:11:59 +010054 if (descriptor.m_DataLayout == DataLayout::NCHW)
55 {
56 permDestShape = TensorShape({ outDepth, inHeight, blockSize, inWidth, blockSize });
57 permVector = { 2, 4, 0, 1, 3 };
58 }
59 else
60 {
61 permDestShape = TensorShape({ inHeight, blockSize, inWidth, blockSize, outDepth });
62 permVector = { 0, 2, 1, 3, 4 };
63 }
64
65 const unsigned int numElementsPerBatch = inputShape.GetNumElements() / batches;
66
67 for (unsigned int batchIndex = 0u; batchIndex < batches; ++batchIndex)
68 {
69 const uintptr_t batchDataOffset = batchIndex * (numElementsPerBatch * dataTypeSize);
70
71 armnnUtils::Permute(permDestShape,
Aron Virginas-Tar9926e582019-09-25 12:44:53 +010072 permVector,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +010073 static_cast<const void*>(reinterpret_cast<const uint8_t*>(inputData) + batchDataOffset),
74 static_cast<void*>(reinterpret_cast<uint8_t*>(outputData) + batchDataOffset),
75 dataTypeSize);
76 }
77}
78
79} // namespace armnn