blob: 56980141811fee97267faf85fd45c958e77b690d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#include "TransposeConvolution2d.hpp"
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
Aron Virginas-Tar735a4502019-06-26 15:02:47 +01009
10namespace armnn
11{
12
13using namespace armnnUtils;
14
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010015void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descriptor,
16 const TensorShape& inputShape,
17 Decoder<float>& inputDecoder,
18 const TensorShape& outputShape,
19 Encoder<float>& outputEncoder,
20 const TensorShape& weightsShape,
21 Decoder<float>& weightsDecoder,
22 Decoder<float>* biasesDecoder)
23{
24 if (descriptor.m_BiasEnabled && !biasesDecoder)
25 {
26 throw InvalidArgumentException("Biases enabled but no bias data provided");
27 }
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010028 const DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
Mike Kellya24d9c72019-08-13 10:06:25 +010029 const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
30 const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
31 const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010032
Mike Kellya24d9c72019-08-13 10:06:25 +010033 unsigned int numBatches = inputShape[0];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010034
Mike Kellya24d9c72019-08-13 10:06:25 +010035 unsigned int inputWidth = inputShape[widthIndex];
36 unsigned int inputHeight = inputShape[heightIndex];
37 unsigned int inputDepth = inputShape[channelsIndex];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010038
Mike Kellya24d9c72019-08-13 10:06:25 +010039 unsigned int weightsHeight = weightsShape[heightIndex];
40 unsigned int weightsWidth = weightsShape[widthIndex];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010041
Mike Kellya24d9c72019-08-13 10:06:25 +010042 unsigned int outputHeight = outputShape[heightIndex];
43 unsigned int outputWidth = outputShape[widthIndex];
44 unsigned int outputDepth = outputShape[channelsIndex];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010045
Mike Kellya24d9c72019-08-13 10:06:25 +010046 unsigned int paddingLeft = descriptor.m_PadLeft;
47 unsigned int paddingTop = descriptor.m_PadTop;
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010048
Mike Kellya24d9c72019-08-13 10:06:25 +010049 unsigned int strideX = descriptor.m_StrideX;
50 unsigned int strideY = descriptor.m_StrideY;
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010051
Mike Kelly675fa4f2019-08-21 09:36:58 +010052 std::vector<float> outputBuffer(outputShape.GetNumElements(), 0);
Mike Kellya24d9c72019-08-13 10:06:25 +010053
54 for (unsigned int batch = 0u; batch < numBatches; ++batch)
55 {
56 for (unsigned int yInput = 0u; yInput < inputHeight; ++yInput)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010057 {
Mike Kellya24d9c72019-08-13 10:06:25 +010058 for (unsigned int xInput = 0u; xInput < inputWidth; ++xInput)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010059 {
Mike Kellya24d9c72019-08-13 10:06:25 +010060 unsigned int xOutputOrigin = xInput * strideX - paddingLeft;
61 unsigned int yOutputOrigin = yInput * strideY - paddingTop;
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010062
Mike Kellya24d9c72019-08-13 10:06:25 +010063 for (unsigned int dOutput = 0u; dOutput < outputDepth; ++dOutput)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010064 {
Mike Kellya24d9c72019-08-13 10:06:25 +010065 for (unsigned int yWeights = 0u; yWeights < weightsHeight; ++yWeights)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010066 {
Aron Virginas-Tar0558ca42019-08-14 11:39:50 +010067 for (unsigned int xWeights = 0u; xWeights < weightsWidth; ++xWeights)
Mike Kellya24d9c72019-08-13 10:06:25 +010068 {
69 unsigned int yOutput = yOutputOrigin + yWeights;
70 unsigned int xOutput = xOutputOrigin + xWeights;
71
72 if (yOutput < outputHeight && xOutput< outputWidth)
73 {
74 for (unsigned int dInput = 0u; dInput < inputDepth; dInput++)
75 {
76 const unsigned int inputIndex =
77 dataLayoutIndexed.GetIndex(inputShape, batch, dInput, yInput, xInput);
78 inputDecoder[inputIndex];
79
80 const unsigned int weightsIndex =
Aron Virginas-Taraec942c2019-08-14 14:37:42 +010081 dataLayoutIndexed.GetIndex(weightsShape, dOutput, dInput, yWeights, xWeights);
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +000082 weightsDecoder.SetIndex(weightsIndex, dOutput);
Mike Kellya24d9c72019-08-13 10:06:25 +010083
84 const unsigned int outputIndex =
85 dataLayoutIndexed.GetIndex(outputShape, batch, dOutput, yOutput, xOutput);
86 outputEncoder[outputIndex];
87
Mike Kelly675fa4f2019-08-21 09:36:58 +010088 float output = outputBuffer[outputIndex];
Mike Kellya24d9c72019-08-13 10:06:25 +010089 output += inputDecoder.Get() * weightsDecoder.Get();
Mike Kelly675fa4f2019-08-21 09:36:58 +010090 outputBuffer[outputIndex] = output;
Mike Kellya24d9c72019-08-13 10:06:25 +010091 }
92 }
93 }
94 }
95 }
96 }
97 }
98 }
99
100 // Apply bias (if enabled)
101 if (descriptor.m_BiasEnabled)
102 {
103 outputEncoder[0];
104 Decoder<float>& rBiasesDecoder = *biasesDecoder;
105
106 for (unsigned int batch = 0u; batch < numBatches; ++batch)
107 {
108 for (unsigned int dOutput = 0u; dOutput < outputDepth; ++dOutput)
109 {
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +0000110 rBiasesDecoder.SetIndex(dOutput, dOutput);
Mike Kellya24d9c72019-08-13 10:06:25 +0100111 for (unsigned int yOutput = 0u; yOutput < outputHeight; ++yOutput)
112 {
113 for (unsigned int xOutput = 0u; xOutput < outputWidth; ++xOutput)
114 {
115 const unsigned int outputIndex =
116 dataLayoutIndexed.GetIndex(outputShape, batch, dOutput, yOutput, xOutput);
Mike Kelly675fa4f2019-08-21 09:36:58 +0100117 outputBuffer[outputIndex] += rBiasesDecoder.Get();
Aron Virginas-Tar735a4502019-06-26 15:02:47 +0100118 }
119 }
120 }
121 }
122 }
Mike Kelly675fa4f2019-08-21 09:36:58 +0100123 outputEncoder[0];
124 for (float output : outputBuffer)
125 {
126 outputEncoder.Set(output);
127 ++outputEncoder;
128 }
Aron Virginas-Tar735a4502019-06-26 15:02:47 +0100129}
130
Mike Kellya24d9c72019-08-13 10:06:25 +0100131} // namespace armnn