blob: db15cefe103ca4f95e81092b2420c13b8c5bc7ea [file] [log] [blame]
Aron Virginas-Tar735a4502019-06-26 15:02:47 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "TransposeConvolution2d.hpp"
7
8#include <DataLayoutIndexed.hpp>
9
10namespace armnn
11{
12
13using namespace armnnUtils;
14
15struct TensorData
16{
17 TensorShape shape;
18 std::vector<float> data;
19};
20
21TensorData SetUpStridedInput(const TensorShape& inputShape,
22 Decoder<float>& inputDecoder,
23 const TransposeConvolution2dDescriptor& descriptor,
24 const DataLayoutIndexed& dataLayoutIndexed)
25{
26 const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex();
27 const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex();
28 const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex();
29
30 const unsigned int batches = inputShape[0];
31 const unsigned int channels = inputShape[cIndex];
32
33 const unsigned int wInput = inputShape[wIndex];
34 const unsigned int hInput = inputShape[hIndex];
35
36 const unsigned int wStridedInput = 1u + descriptor.m_StrideX * (wInput - 1);
37 const unsigned int hStridedInput = 1u + descriptor.m_StrideY * (hInput - 1);
38
39 TensorData stridedInput;
40 stridedInput.data = std::vector<float>(batches * channels * wStridedInput * hStridedInput, 0.0f);
41 stridedInput.shape = TensorShape(4);
42
43 stridedInput.shape[0] = batches;
44 stridedInput.shape[cIndex] = channels;
45 stridedInput.shape[hIndex] = hStridedInput;
46 stridedInput.shape[wIndex] = wStridedInput;
47
48 // expand input data with strides
49 for (unsigned int batchIdx = 0u; batchIdx < batches; ++batchIdx)
50 {
51 for (unsigned int cInput = 0u; cInput < channels; ++cInput)
52 {
53 for (unsigned int yInput = 0u, yStrided = 0u;
54 yInput < hInput && yStrided < hStridedInput;
55 ++yInput, yStrided += descriptor.m_StrideY)
56 {
57 for (unsigned int xInput = 0u, xStrided = 0u;
58 xInput < wInput && xStrided < wStridedInput;
59 ++xInput, xStrided += descriptor.m_StrideX)
60 {
61 unsigned int inputIdx =
62 dataLayoutIndexed.GetIndex(inputShape, batchIdx, cInput, yInput, xInput);
63 unsigned int stridedInputIdx =
64 dataLayoutIndexed.GetIndex(stridedInput.shape, batchIdx, cInput, yStrided, xStrided);
65
66 inputDecoder[inputIdx];
67 stridedInput.data[stridedInputIdx] = inputDecoder.Get();
68 }
69 }
70 }
71 }
72
73 return stridedInput;
74}
75
76TensorData SetUpEmptyPaddedOutput(const TensorShape& outputShape,
77 const TransposeConvolution2dDescriptor& descriptor,
78 const DataLayoutIndexed& dataLayoutIndexed)
79{
80 const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex();
81 const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex();
82 const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex();
83
84 const unsigned int batches = outputShape[0];
85 const unsigned int channels = outputShape[cIndex];
86
87 const unsigned int wOutput = outputShape[wIndex];
88 const unsigned int hOutput = outputShape[hIndex];
89
90 const unsigned int wPaddedOutput = wOutput + descriptor.m_PadLeft + descriptor.m_PadRight;
91 const unsigned int hPaddedOutput = hOutput + descriptor.m_PadTop + descriptor.m_PadBottom;
92
93 TensorData paddedOutput;
94 paddedOutput.data = std::vector<float>(batches * channels * wPaddedOutput * hPaddedOutput, 0.0f);
95 paddedOutput.shape = TensorShape(4);
96
97 paddedOutput.shape[0] = batches;
98 paddedOutput.shape[cIndex] = channels;
99 paddedOutput.shape[hIndex] = hPaddedOutput;
100 paddedOutput.shape[wIndex] = wPaddedOutput;
101
102 return paddedOutput;
103}
104
105void Deconvolve(const TensorData& stridedInput,
106 TensorData& paddedOutput,
107 const TensorShape& weightsShape,
108 Decoder<float>& weightsDecoder,
109 const DataLayoutIndexed& dataLayoutIndexed)
110{
111 const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex();
112 const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex();
113 const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex();
114
115 const unsigned int batches = stridedInput.shape[0];
116 const unsigned int channels = stridedInput.shape[cIndex];
117
118 const unsigned int wKernel = weightsShape[wIndex];
119 const unsigned int hKernel = weightsShape[hIndex];
120
121 const unsigned int wStridedInput = stridedInput.shape[wIndex];
122 const unsigned int hStridedInput = stridedInput.shape[hIndex];
123
124 // loop through all input elements
125 for (unsigned int batchIdx = 0u; batchIdx < batches; ++batchIdx)
126 {
127 for (unsigned int cInput = 0u; cInput < channels; ++cInput)
128 {
129 for (unsigned int yInput = 0u; yInput < hStridedInput; ++yInput)
130 {
131 for (unsigned int xInput = 0u; xInput < wStridedInput; ++xInput)
132 {
133 // obtain input value
134 unsigned int inputIdx =
135 dataLayoutIndexed.GetIndex(stridedInput.shape, batchIdx, cInput, yInput, xInput);
136 float inputValue = stridedInput.data[inputIdx];
137
138 // loop through kernel
139 for (unsigned int yKernel = 0u; yKernel < hKernel; ++yKernel)
140 {
141 for (unsigned int xKernel = 0; xKernel < wKernel; ++xKernel)
142 {
143 unsigned int kernelIdx =
144 dataLayoutIndexed.GetIndex(weightsShape, batchIdx, cInput, yKernel, xKernel);
145
146 weightsDecoder[kernelIdx];
147 float kernelValue = weightsDecoder.Get();
148
149 unsigned int xOutput = xInput + xKernel;
150 unsigned int yOutput = yInput + yKernel;
151
152 // compute output increment
153 float outputValue = inputValue * kernelValue;
154
155 unsigned int outputIdx = dataLayoutIndexed.GetIndex(paddedOutput.shape,
156 batchIdx,
157 cInput,
158 yOutput,
159 xOutput);
160
161 // set output value
162 paddedOutput.data[outputIdx] += outputValue;
163 }
164 }
165 }
166 }
167 }
168 }
169}
170
171void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descriptor,
172 const TensorShape& inputShape,
173 Decoder<float>& inputDecoder,
174 const TensorShape& outputShape,
175 Encoder<float>& outputEncoder,
176 const TensorShape& weightsShape,
177 Decoder<float>& weightsDecoder,
178 Decoder<float>* biasesDecoder)
179{
180 if (descriptor.m_BiasEnabled && !biasesDecoder)
181 {
182 throw InvalidArgumentException("Biases enabled but no bias data provided");
183 }
184
185 const DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
186
187 const unsigned int cIndex = dataLayoutIndexed.GetChannelsIndex();
188 const unsigned int hIndex = dataLayoutIndexed.GetHeightIndex();
189 const unsigned int wIndex = dataLayoutIndexed.GetWidthIndex();
190
191 const unsigned int numBatches = inputShape[0];
192 const unsigned int numChannels = inputShape[cIndex];
193
194 // set up temporary strided input
195 TensorData stridedInput = SetUpStridedInput(inputShape, inputDecoder, descriptor, dataLayoutIndexed);
196
197 // set up temporary (empty) padded output
198 TensorData paddedOutput = SetUpEmptyPaddedOutput(outputShape, descriptor, dataLayoutIndexed);
199
200 // run deconvolution (without biases) on strided input to produce padded output
201 Deconvolve(stridedInput, paddedOutput, weightsShape, weightsDecoder, dataLayoutIndexed);
202
203 const unsigned int wPaddedOutput = paddedOutput.shape[wIndex];
204 const unsigned int hPaddedOutput = paddedOutput.shape[hIndex];
205
206 // remove padding and apply bias (if enabled)
207 for (unsigned int batchIdx = 0u; batchIdx < numBatches; ++batchIdx)
208 {
209 for (unsigned int cOutput = 0u; cOutput < numChannels; ++cOutput)
210 {
211 // update bias decoder iterator
212 if (descriptor.m_BiasEnabled)
213 {
214 (*biasesDecoder)[cOutput];
215 }
216
217 for (unsigned int yPaddedOutput = descriptor.m_PadTop;
218 yPaddedOutput < (hPaddedOutput - descriptor.m_PadBottom);
219 ++yPaddedOutput)
220 {
221 for (unsigned int xPaddedOutput = descriptor.m_PadLeft;
222 xPaddedOutput < (wPaddedOutput - descriptor.m_PadRight);
223 ++xPaddedOutput)
224 {
225 unsigned int xOutput = xPaddedOutput - descriptor.m_PadLeft;
226 unsigned int yOutput = yPaddedOutput - descriptor.m_PadTop;
227
228 unsigned int outputIdx =
229 dataLayoutIndexed.GetIndex(outputShape, batchIdx, cOutput, yOutput, xOutput);
230 unsigned int paddedOutputIdx =
231 dataLayoutIndexed.GetIndex(paddedOutput.shape, batchIdx, cOutput, yPaddedOutput, xPaddedOutput);
232
233 // encode (copy) output data
234 outputEncoder[outputIdx];
235 outputEncoder.Set(paddedOutput.data[paddedOutputIdx]);
236
237 // apply bias (if enabled)
238 if (descriptor.m_BiasEnabled)
239 {
240 outputEncoder.Set(outputEncoder.Get() + biasesDecoder->Get());
241 }
242 }
243 }
244 }
245 }
246}
247
248} // namespace armnn