blob: b6bab17367bfe11b18411a8e8656b2f37a998263 [file] [log] [blame]
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "SpaceToBatchNd.hpp"
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
Matteo Martincigh21350152018-11-28 16:22:22 +00009
10using namespace armnnUtils;
Matthew Bentham8800c002018-11-19 13:19:28 +000011
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +000012namespace armnn
13{
14
15unsigned int GetOffset(const TensorShape& shape,
16 unsigned int b,
17 unsigned int h,
18 unsigned int w,
19 unsigned int c,
20 const DataLayoutIndexed& dataLayout)
21{
22 if (dataLayout.GetDataLayout() == DataLayout::NHWC)
23 {
24 return ((b * shape[dataLayout.GetHeightIndex()] + h) * shape[dataLayout.GetWidthIndex()] + w) *
25 shape[dataLayout.GetChannelsIndex()] + c;
26 }
27 else
28 {
29 return ((b * shape[dataLayout.GetChannelsIndex()] + c) * shape[dataLayout.GetHeightIndex()] + h) *
30 shape[dataLayout.GetWidthIndex()] + w;
31 }
32}
33
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +000034void SpaceToBatchNd(const TensorInfo& inputInfo,
35 const TensorInfo& outputInfo,
36 const SpaceToBatchNdDescriptor& params,
nikraj0122f0f2b2019-05-30 17:29:32 +010037 Decoder<float>& inputData,
38 Encoder<float>& outputData)
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +000039{
40 DataLayoutIndexed dataLayout = params.m_DataLayout;
41
42 const TensorShape& inputShape = inputInfo.GetShape();
43 const TensorShape& outputShape = outputInfo.GetShape();
44
45 const unsigned int channels = inputShape[dataLayout.GetChannelsIndex()];
46
47 const unsigned int inputBatchSize = inputShape[0];
48 const unsigned int inputHeight = inputShape[dataLayout.GetHeightIndex()];
49 const unsigned int inputWidth = inputShape[dataLayout.GetWidthIndex()];
50
51 const unsigned int outputBatchSize = outputShape[0];
52 const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
53 const unsigned int outputWidth = outputShape[dataLayout.GetWidthIndex()];
54
55 const unsigned int blockHeight = params.m_BlockShape[0];
56 const unsigned int blockWidth = params.m_BlockShape[1];
57
58 const unsigned int paddingTop = params.m_PadList[0].first;
59 const unsigned int paddingLeft = params.m_PadList[1].first;
60
61 for (unsigned int outB = 0; outB < outputBatchSize; outB++)
62 {
63 unsigned int inB = outB % inputBatchSize;
64
65 unsigned int shiftW = (outB / inputBatchSize) % blockWidth;
66 unsigned int shiftH = (outB / inputBatchSize) / blockWidth;
67
68 for (unsigned int outH = 0; outH < outputHeight; outH++)
69 {
70 for (unsigned int outW = 0; outW < outputWidth; outW++)
71 {
72 if (outH * blockHeight + shiftH < paddingTop ||
73 outH * blockHeight + shiftH >= paddingTop + inputHeight ||
74 outW * blockWidth + shiftW < paddingLeft ||
75 outW * blockWidth + shiftW >= paddingLeft + inputWidth)
76 {
77 for (unsigned int c = 0; c < channels; c++)
78 {
79 unsigned int outOffset = GetOffset(outputShape,
80 outB,
81 outH,
82 outW,
83 c,
84 dataLayout);
nikraj0122f0f2b2019-05-30 17:29:32 +010085 outputData += outOffset;
86 outputData.Set(0);
87 outputData -= outOffset;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +000088 }
89 }
90 else
91 {
92 for (unsigned int c = 0; c < channels; c++)
93 {
94 unsigned int inOffset = GetOffset(inputShape,
95 inB,
96 (outH * blockHeight + shiftH) - paddingTop,
97 (outW * blockWidth + shiftW) - paddingLeft,
98 c,
99 dataLayout);
100
101 unsigned int outOffset = GetOffset(outputShape,
102 outB,
103 outH,
104 outW,
105 c,
106 dataLayout);
107
nikraj0122f0f2b2019-05-30 17:29:32 +0100108 outputData += outOffset;
109 inputData += inOffset;
110 outputData.Set(inputData.Get());
111 inputData -= inOffset;
112 outputData -= outOffset;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000113 }
114 }
115 }
116 }
117 }
118}
119
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +0000125
126} //namespace armnn