blob: 1f8b674c3ade903186a0967ccef88f37cae1d7db [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "Pad.hpp"
Sadik Armagan041b3c02020-06-04 10:32:18 +01007
8#include "BaseIterator.hpp"
9#include "Decoders.hpp"
David Monahan34757812019-06-19 11:47:21 +010010#include "Encoders.hpp"
11
Sadik Armagan041b3c02020-06-04 10:32:18 +010012#include <armnnUtils/TensorUtils.hpp>
13
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010014#include <cmath>
15#include <cstddef>
16#include <functional>
17#include <limits>
18#include <cassert>
19
Sadik Armagan041b3c02020-06-04 10:32:18 +010020namespace
21{
22
23void FillOutputWithPadValue(armnn::Encoder<float>& outputData,
24 const float padValue,
25 const unsigned int numOutputElements)
26{
27 for (unsigned int i = 0; i < numOutputElements; ++i)
28 {
29 outputData[i];
30 outputData.Set(padValue);
31 }
32}
33
34} // anonymous namespace
35
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010036namespace armnn
37{
David Monahan34757812019-06-19 11:47:21 +010038
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010039void Pad(const TensorInfo& inputInfo,
40 const TensorInfo& outputInfo,
Sadik Armagan041b3c02020-06-04 10:32:18 +010041 const PadQueueDescriptor& data)
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010042{
Sadik Armagan041b3c02020-06-04 10:32:18 +010043 auto padList = data.m_Parameters.m_PadList;
44 auto padValue = data.m_Parameters.m_PadValue;
45
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010046 unsigned int numOutputElements = outputInfo.GetNumElements();
47
48 TensorShape outputShape = outputInfo.GetShape();
Sadik Armagan041b3c02020-06-04 10:32:18 +010049 TensorShape inputShape = inputInfo.GetShape();
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010050
51 unsigned int numInputDimensions = inputShape.GetNumDimensions();
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010052
Sadik Armagan041b3c02020-06-04 10:32:18 +010053#ifndef NDEBUG
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010054
55 unsigned int numOutputDimensions = outputShape.GetNumDimensions();
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010056 assert(numInputDimensions == numOutputDimensions);
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010057
Sadik Armagan041b3c02020-06-04 10:32:18 +010058#endif
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010059
Sadik Armagan041b3c02020-06-04 10:32:18 +010060 unsigned int inputBatches = 0;
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010061 unsigned int inputChannels = 0;
Sadik Armagan041b3c02020-06-04 10:32:18 +010062 unsigned int inputHeight = 0;
63 unsigned int inputWidth = 0;
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010064
65 unsigned int outputChannels = 0;
Sadik Armagan041b3c02020-06-04 10:32:18 +010066 unsigned int outputHeight = 0;
67 unsigned int outputWidth = 0;
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010068
Sadik Armagan041b3c02020-06-04 10:32:18 +010069 auto inputData = MakeDecoder<float>(inputInfo, data.m_Inputs[0]->Map());
70 auto outData = MakeEncoder<float>(outputInfo, data.m_Outputs[0]->Map());
David Monahan34757812019-06-19 11:47:21 +010071
Sadik Armagan041b3c02020-06-04 10:32:18 +010072 // Fill the output tensor with Pad value first
73 if (outputInfo.IsQuantized())
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010074 {
Sadik Armagan041b3c02020-06-04 10:32:18 +010075 // For Quantized types Pad Value should not be quantized with scale and offset of the tensor info
76 auto temporaryInfo = TensorInfo(outputInfo.GetShape(), outputInfo.GetDataType(), 1.0f, 0);
77 auto outputData = MakeEncoder<float>(temporaryInfo, data.m_Outputs[0]->Map());
78 FillOutputWithPadValue(*outputData, padValue, numOutputElements);
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010079 }
Sadik Armagan041b3c02020-06-04 10:32:18 +010080 else
81 {
82 FillOutputWithPadValue(*outData, padValue, numOutputElements);
83 }
84
85 Decoder<float>& input = *inputData;
86 Encoder<float>& output = *outData;
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010087
88 switch(numInputDimensions) {
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010089
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010090 case 1:
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010091 inputWidth = inputShape[0];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010092 for (unsigned int w = 0; w < inputWidth ; w++)
93 {
Sadik Armagan041b3c02020-06-04 10:32:18 +010094 input[w];
95 auto inputValue = input.Get();
96 auto outputIndex = w + std::get<0>(padList[0]);
97 output[outputIndex];
98 output.Set(inputValue);
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010099 }
100
101 break;
102 case 2 :
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100103 inputHeight = inputShape[0];
Sadik Armagan041b3c02020-06-04 10:32:18 +0100104 inputWidth = inputShape[1];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100105 outputWidth = outputShape[1];
106
107 for (unsigned int h = 0; h < inputHeight; h++)
108 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100109 for (unsigned int w = 0; w < inputWidth ; w++)
110 {
Sadik Armagan041b3c02020-06-04 10:32:18 +0100111 input[h * inputWidth + w];
112 auto inputValue = input.Get();
113 auto outputIndex = (h + std::get<0>(padList[0])) * outputWidth + (w + std::get<0>(padList[1]));
114 output[outputIndex];
115 output.Set(inputValue);
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100116 }
117 }
118
119 break;
120 case 3 :
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100121 inputChannels = inputShape[0];
Sadik Armagan041b3c02020-06-04 10:32:18 +0100122 inputHeight = inputShape[1];
123 inputWidth = inputShape[2];
124 outputHeight = outputShape[1];
125 outputWidth = outputShape[2];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100126
127 for (unsigned int c = 0; c < inputChannels; c++)
128 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100129 for (unsigned int h = 0; h < inputHeight; h++)
130 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100131 for (unsigned int w = 0; w < inputWidth ; w++)
132 {
Sadik Armagan041b3c02020-06-04 10:32:18 +0100133 input[c * inputHeight * inputWidth + h * inputWidth + w];
134 auto inputValue = input.Get();
135 auto outputIndex = (c + std::get<0>(padList[0])) * outputHeight * outputWidth
136 + (h + std::get<0>(padList[1])) * outputWidth
137 + (w + std::get<0>(padList[2]));
138 output[outputIndex];
139 output.Set(inputValue);
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100140 }
141 }
142 }
143
144 break;
145 case 4 :
Sadik Armagan041b3c02020-06-04 10:32:18 +0100146 inputBatches = inputShape[0];
147 inputChannels = inputShape[1];
148 inputHeight = inputShape[2];
149 inputWidth = inputShape[3];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100150 outputChannels = outputShape[1];
Sadik Armagan041b3c02020-06-04 10:32:18 +0100151 outputHeight = outputShape[2];
152 outputWidth = outputShape[3];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100153
154 for (unsigned int b = 0; b < inputBatches; b++)
155 {
156 for (unsigned int c = 0; c < inputChannels; c++)
157 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100158 for (unsigned int h = 0; h < inputHeight; h++)
159 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100160 for (unsigned int w = 0; w < inputWidth ; w++)
161 {
Sadik Armagan041b3c02020-06-04 10:32:18 +0100162 input[b * inputChannels * inputHeight * inputWidth
163 + c * inputHeight * inputWidth
164 + h * inputWidth
165 + w];
166 auto inputValue = input.Get();
167 auto outputIndex = (b + std::get<0>(padList[0]))
168 * outputChannels * outputHeight * outputWidth
169 + (c + std::get<0>(padList[1])) * outputHeight * outputWidth
170 + (h + std::get<0>(padList[2])) * outputWidth
171 + (w + std::get<0>(padList[3]));
172 output[outputIndex];
173 output.Set(inputValue);
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100174 }
175 }
176 }
177 }
178
179 break;
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100180 default :
181 break;
182 }
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100183}
184
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100185} //namespace armnn