blob: 41435f47d26bb57378b1235983947b01bb49f605 [file] [log] [blame]
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "Pad.hpp"
#include "backendsCommon/WorkloadData.hpp"
#include "TensorBufferArrayView.hpp"
#include "Encoders.hpp"

#include <boost/numeric/conversion/cast.hpp>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>
17
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010018namespace armnn
19{
David Monahan34757812019-06-19 11:47:21 +010020
21template <typename T>
22T ConvertToDataType(const float& value,
23 const armnn::TensorInfo& tensorInfo)
24{
25 std::vector<T> output(1);
26 std::unique_ptr<armnn::Encoder<float>> pEncoder = armnn::MakeEncoder<float>(tensorInfo, output.data());
27 armnn::Encoder<float>& rEncoder = *pEncoder;
28 rEncoder.Set(value);
29 return output[0];
30}
31
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010032template <typename T>
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010033void Pad(const TensorInfo& inputInfo,
34 const TensorInfo& outputInfo,
David Monahan34757812019-06-19 11:47:21 +010035 std::vector<std::pair<unsigned int, unsigned int>> m_padList,
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010036 const T* inputData,
David Monahan34757812019-06-19 11:47:21 +010037 T* outData,
38 const float padValue)
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010039{
40 unsigned int numOutputElements = outputInfo.GetNumElements();
41
42 TensorShape outputShape = outputInfo.GetShape();
43 TensorShape inputShape = inputInfo.GetShape();
44
45 unsigned int numInputDimensions = inputShape.GetNumDimensions();
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010046
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010047 #ifndef NDEBUG
48
49 unsigned int numOutputDimensions = outputShape.GetNumDimensions();
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010050 assert(numInputDimensions == numOutputDimensions);
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010051
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010052 #endif
53
54 unsigned int inputBatches = 0;
55 unsigned int inputChannels = 0;
56 unsigned int inputHeight = 0;
57 unsigned int inputWidth = 0;
58
59 unsigned int outputChannels = 0;
60 unsigned int outputHeight = 0;
61 unsigned int outputWidth = 0;
62
David Monahan34757812019-06-19 11:47:21 +010063 T convertedPadValue = ConvertToDataType<T>(padValue, inputInfo);
64
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010065 for (unsigned int i = 0; i < numOutputElements; ++i)
66 {
David Monahan34757812019-06-19 11:47:21 +010067 outData[i] = convertedPadValue;
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010068 }
69
70 switch(numInputDimensions) {
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010071
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010072 case 1:
73
74 inputWidth = inputShape[0];
75
76 for (unsigned int w = 0; w < inputWidth ; w++)
77 {
David Monahan34757812019-06-19 11:47:21 +010078 outData[w+std::get<0>(m_padList[0])] = inputData[w];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010079 }
80
81 break;
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010082
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010083 case 2 :
84
85 inputHeight = inputShape[0];
86 inputWidth = inputShape[1];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010087 outputHeight = outputShape[0];
88 outputWidth = outputShape[1];
89
90 for (unsigned int h = 0; h < inputHeight; h++)
91 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010092 for (unsigned int w = 0; w < inputWidth ; w++)
93 {
David Monahan34757812019-06-19 11:47:21 +010094 outData[(h+std::get<0>(m_padList[0]))*outputWidth
95 + (w+std::get<0>(m_padList[1]))] = inputData[h * inputWidth + w];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010096 }
97 }
98
99 break;
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100100
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100101 case 3 :
102
103 inputChannels = inputShape[0];
104 inputHeight = inputShape[1];
105 inputWidth = inputShape[2];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100106 outputChannels = outputShape[0];
107 outputHeight = outputShape[1];
108 outputWidth = outputShape[2];
109
110 for (unsigned int c = 0; c < inputChannels; c++)
111 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100112 for (unsigned int h = 0; h < inputHeight; h++)
113 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100114 for (unsigned int w = 0; w < inputWidth ; w++)
115 {
David Monahan34757812019-06-19 11:47:21 +0100116 outData[(c+std::get<0>(m_padList[0]))*outputHeight*outputWidth
117 + (h+std::get<0>(m_padList[1]))*outputWidth
118 + (w+std::get<0>(m_padList[2]))] = inputData[c * inputHeight * inputWidth
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100119 + h * inputWidth
120 + w];
121 }
122 }
123 }
124
125 break;
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100126
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100127 case 4 :
128
129 inputBatches = inputShape[0];
130 inputChannels = inputShape[1];
131 inputHeight = inputShape[2];
132 inputWidth = inputShape[3];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100133 outputChannels = outputShape[1];
134 outputHeight = outputShape[2];
135 outputWidth = outputShape[3];
136
137 for (unsigned int b = 0; b < inputBatches; b++)
138 {
139 for (unsigned int c = 0; c < inputChannels; c++)
140 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100141 for (unsigned int h = 0; h < inputHeight; h++)
142 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100143 for (unsigned int w = 0; w < inputWidth ; w++)
144 {
David Monahan34757812019-06-19 11:47:21 +0100145 outData[(b+std::get<0>(m_padList[0])) * outputChannels * outputHeight * outputWidth
146 + (c+std::get<0>(m_padList[1])) * outputHeight * outputWidth
147 + (h+std::get<0>(m_padList[2])) * outputWidth
148 + (w+std::get<0>(m_padList[3]))] = inputData[b * inputChannels * inputHeight
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100149 * inputWidth
150 + c * inputHeight * inputWidth
151 + h * inputWidth
152 + w];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100153 }
154 }
155 }
156 }
157
158 break;
159
160 default :
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100161
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100162 break;
163 }
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100164}
165
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100166template void Pad<float>(const TensorInfo& inputInfo,
167 const TensorInfo& outputInfo,
168 std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
169 const float* inputData,
David Monahan34757812019-06-19 11:47:21 +0100170 float* outData,
171 const float padValue);
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100172template void Pad<uint8_t>(const TensorInfo& inputInfo,
173 const TensorInfo& outputInfo,
174 std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
175 const uint8_t* inputData,
David Monahan34757812019-06-19 11:47:21 +0100176 uint8_t* outData,
177 const float padValue);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +0100178template void Pad<int16_t>(const TensorInfo& inputInfo,
179 const TensorInfo& outputInfo,
180 std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
181 const int16_t* inputData,
182 int16_t* outData,
183 const float padValue);
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100184
185} //namespace armnn