blob: a1a6cbae68b572c2c3118e70627754394eeecefd [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "TransposeConvolution2d.hpp"
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/DataLayoutIndexed.hpp>
Aron Virginas-Tar735a4502019-06-26 15:02:47 +01009
10namespace armnn
11{
12
13using namespace armnnUtils;
14
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010015void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descriptor,
16 const TensorShape& inputShape,
17 Decoder<float>& inputDecoder,
18 const TensorShape& outputShape,
19 Encoder<float>& outputEncoder,
20 const TensorShape& weightsShape,
21 Decoder<float>& weightsDecoder,
22 Decoder<float>* biasesDecoder)
23{
24 if (descriptor.m_BiasEnabled && !biasesDecoder)
25 {
26 throw InvalidArgumentException("Biases enabled but no bias data provided");
27 }
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010028 const DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
Mike Kellya24d9c72019-08-13 10:06:25 +010029 const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
30 const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
31 const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010032
Finn Williamsb9dcfe62020-09-17 15:58:31 +010033 const unsigned int numBatches = inputShape[0];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010034
Finn Williamsb9dcfe62020-09-17 15:58:31 +010035 const unsigned int inputWidth = inputShape[widthIndex];
36 const unsigned int inputHeight = inputShape[heightIndex];
37 const unsigned int inputDepth = inputShape[channelsIndex];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010038
Finn Williamsb9dcfe62020-09-17 15:58:31 +010039 const unsigned int weightsHeight = weightsShape[heightIndex];
40 const unsigned int weightsWidth = weightsShape[widthIndex];
41 const unsigned int weightsDepth = weightsShape[channelsIndex];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010042
Finn Williamsb9dcfe62020-09-17 15:58:31 +010043 const unsigned int outputHeight = outputShape[heightIndex];
44 const unsigned int outputWidth = outputShape[widthIndex];
45 const unsigned int outputDepth = outputShape[channelsIndex];
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010046
Finn Williamsb9dcfe62020-09-17 15:58:31 +010047 const unsigned int paddingLeft = descriptor.m_PadLeft;
48 const unsigned int paddingTop = descriptor.m_PadTop;
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010049
Finn Williamsb9dcfe62020-09-17 15:58:31 +010050 const unsigned int strideX = descriptor.m_StrideX;
51 const unsigned int strideY = descriptor.m_StrideY;
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010052
Mike Kelly675fa4f2019-08-21 09:36:58 +010053 std::vector<float> outputBuffer(outputShape.GetNumElements(), 0);
Mike Kellya24d9c72019-08-13 10:06:25 +010054
Finn Williamsea8ce702020-09-29 19:54:00 +010055 const std::vector<float> inputVec = inputDecoder.DecodeTensor(inputShape);
56 const std::vector<float> filterVec = weightsDecoder.DecodeTensor(weightsShape);
Finn Williamsb9dcfe62020-09-17 15:58:31 +010057
Mike Kellya24d9c72019-08-13 10:06:25 +010058 for (unsigned int batch = 0u; batch < numBatches; ++batch)
59 {
60 for (unsigned int yInput = 0u; yInput < inputHeight; ++yInput)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010061 {
Mike Kellya24d9c72019-08-13 10:06:25 +010062 for (unsigned int xInput = 0u; xInput < inputWidth; ++xInput)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010063 {
Mike Kellya24d9c72019-08-13 10:06:25 +010064 unsigned int xOutputOrigin = xInput * strideX - paddingLeft;
65 unsigned int yOutputOrigin = yInput * strideY - paddingTop;
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010066
Mike Kellya24d9c72019-08-13 10:06:25 +010067 for (unsigned int dOutput = 0u; dOutput < outputDepth; ++dOutput)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010068 {
Mike Kellya24d9c72019-08-13 10:06:25 +010069 for (unsigned int yWeights = 0u; yWeights < weightsHeight; ++yWeights)
Aron Virginas-Tar735a4502019-06-26 15:02:47 +010070 {
Aron Virginas-Tar0558ca42019-08-14 11:39:50 +010071 for (unsigned int xWeights = 0u; xWeights < weightsWidth; ++xWeights)
Mike Kellya24d9c72019-08-13 10:06:25 +010072 {
73 unsigned int yOutput = yOutputOrigin + yWeights;
74 unsigned int xOutput = xOutputOrigin + xWeights;
75
76 if (yOutput < outputHeight && xOutput< outputWidth)
77 {
78 for (unsigned int dInput = 0u; dInput < inputDepth; dInput++)
79 {
Finn Williamsb9dcfe62020-09-17 15:58:31 +010080 unsigned int inputIndex;
81 unsigned int outputIndex;
82 unsigned int weightsIndex;
Mike Kellya24d9c72019-08-13 10:06:25 +010083
Finn Williamsb9dcfe62020-09-17 15:58:31 +010084 if(descriptor.m_DataLayout == armnn::DataLayout::NHWC)
85 {
86 inputIndex = batch * inputHeight * inputWidth * inputDepth +
87 yInput * inputWidth * inputDepth +
88 xInput * inputDepth +
89 dInput;
Mike Kellya24d9c72019-08-13 10:06:25 +010090
Finn Williamsb9dcfe62020-09-17 15:58:31 +010091 weightsIndex = dOutput * weightsHeight * weightsWidth * weightsDepth +
92 yWeights * weightsWidth * weightsDepth +
93 xWeights * weightsDepth +
94 dInput;
Mike Kellya24d9c72019-08-13 10:06:25 +010095
Finn Williamsb9dcfe62020-09-17 15:58:31 +010096 outputIndex = batch * outputHeight * outputWidth * outputDepth +
97 yOutput * outputWidth * outputDepth +
98 xOutput * outputDepth +
99 dOutput;
100 }
101 else
102 {
103 inputIndex = batch * inputDepth * inputHeight * inputWidth +
104 dInput * inputHeight * inputWidth +
105 yInput * inputWidth +
106 xInput;
107
108 weightsIndex = dOutput * weightsDepth * weightsHeight * weightsWidth +
109 dInput * weightsHeight * weightsWidth +
110 yWeights * weightsWidth +
111 xWeights;
112
113 outputIndex = batch * outputDepth * outputHeight * outputWidth +
114 dOutput * outputHeight * outputWidth +
115 yOutput * outputWidth +
116 xOutput;
117 }
118
119 outputBuffer[outputIndex] += inputVec[inputIndex] * filterVec[weightsIndex];
Mike Kellya24d9c72019-08-13 10:06:25 +0100120 }
121 }
122 }
123 }
Finn Williamsb9dcfe62020-09-17 15:58:31 +0100124
Mike Kellya24d9c72019-08-13 10:06:25 +0100125 }
126 }
127 }
128 }
129
130 // Apply bias (if enabled)
131 if (descriptor.m_BiasEnabled)
132 {
133 outputEncoder[0];
134 Decoder<float>& rBiasesDecoder = *biasesDecoder;
135
136 for (unsigned int batch = 0u; batch < numBatches; ++batch)
137 {
138 for (unsigned int dOutput = 0u; dOutput < outputDepth; ++dOutput)
139 {
Jan Eilers53ef7952021-06-02 12:01:25 +0100140 rBiasesDecoder[dOutput];
Mike Kellya24d9c72019-08-13 10:06:25 +0100141 for (unsigned int yOutput = 0u; yOutput < outputHeight; ++yOutput)
142 {
143 for (unsigned int xOutput = 0u; xOutput < outputWidth; ++xOutput)
144 {
145 const unsigned int outputIndex =
146 dataLayoutIndexed.GetIndex(outputShape, batch, dOutput, yOutput, xOutput);
Mike Kelly675fa4f2019-08-21 09:36:58 +0100147 outputBuffer[outputIndex] += rBiasesDecoder.Get();
Aron Virginas-Tar735a4502019-06-26 15:02:47 +0100148 }
149 }
150 }
151 }
152 }
Mike Kelly675fa4f2019-08-21 09:36:58 +0100153 outputEncoder[0];
154 for (float output : outputBuffer)
155 {
156 outputEncoder.Set(output);
157 ++outputEncoder;
158 }
Aron Virginas-Tar735a4502019-06-26 15:02:47 +0100159}
160
Mike Kellya24d9c72019-08-13 10:06:25 +0100161} // namespace armnn