blob: 7a928a13361f55085a1fba907bf239885f5b559c [file] [log] [blame]
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "Pad.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include "backendsCommon/WorkloadData.hpp"
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +01008#include <boost/numeric/conversion/cast.hpp>
9#include "TensorBufferArrayView.hpp"
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010010#include <cmath>
11#include <cstddef>
12#include <functional>
13#include <limits>
14#include <cassert>
15
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010016namespace armnn
17{
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010018template <typename T>
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010019void Pad(const TensorInfo& inputInfo,
20 const TensorInfo& outputInfo,
21 std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010022 const T* inputData,
23 T* outData)
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010024{
25 unsigned int numOutputElements = outputInfo.GetNumElements();
26
27 TensorShape outputShape = outputInfo.GetShape();
28 TensorShape inputShape = inputInfo.GetShape();
29
30 unsigned int numInputDimensions = inputShape.GetNumDimensions();
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010031
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010032 #ifndef NDEBUG
33
34 unsigned int numOutputDimensions = outputShape.GetNumDimensions();
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010035 assert(numInputDimensions == numOutputDimensions);
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010036
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010037 #endif
38
39 unsigned int inputBatches = 0;
40 unsigned int inputChannels = 0;
41 unsigned int inputHeight = 0;
42 unsigned int inputWidth = 0;
43
44 unsigned int outputChannels = 0;
45 unsigned int outputHeight = 0;
46 unsigned int outputWidth = 0;
47
48 for (unsigned int i = 0; i < numOutputElements; ++i)
49 {
50 outData[i] = 0;
51 }
52
53 switch(numInputDimensions) {
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010054
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010055 case 1:
56
57 inputWidth = inputShape[0];
58
59 for (unsigned int w = 0; w < inputWidth ; w++)
60 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010061 outData[w+std::get<0>(m_PadList[0])] = inputData[w];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010062 }
63
64 break;
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010065
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010066 case 2 :
67
68 inputHeight = inputShape[0];
69 inputWidth = inputShape[1];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010070 outputHeight = outputShape[0];
71 outputWidth = outputShape[1];
72
73 for (unsigned int h = 0; h < inputHeight; h++)
74 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010075 for (unsigned int w = 0; w < inputWidth ; w++)
76 {
77 outData[(h+std::get<0>(m_PadList[0]))*outputWidth
78 + (w+std::get<0>(m_PadList[1]))] = inputData[h * inputWidth + w];
79 }
80 }
81
82 break;
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +010083
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010084 case 3 :
85
86 inputChannels = inputShape[0];
87 inputHeight = inputShape[1];
88 inputWidth = inputShape[2];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010089 outputChannels = outputShape[0];
90 outputHeight = outputShape[1];
91 outputWidth = outputShape[2];
92
93 for (unsigned int c = 0; c < inputChannels; c++)
94 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010095 for (unsigned int h = 0; h < inputHeight; h++)
96 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010097 for (unsigned int w = 0; w < inputWidth ; w++)
98 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +010099 outData[(c+std::get<0>(m_PadList[0]))*outputHeight*outputWidth
100 + (h+std::get<0>(m_PadList[1]))*outputWidth
101 + (w+std::get<0>(m_PadList[2]))] = inputData[c * inputHeight * inputWidth
102 + h * inputWidth
103 + w];
104 }
105 }
106 }
107
108 break;
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100109
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100110 case 4 :
111
112 inputBatches = inputShape[0];
113 inputChannels = inputShape[1];
114 inputHeight = inputShape[2];
115 inputWidth = inputShape[3];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100116 outputChannels = outputShape[1];
117 outputHeight = outputShape[2];
118 outputWidth = outputShape[3];
119
120 for (unsigned int b = 0; b < inputBatches; b++)
121 {
122 for (unsigned int c = 0; c < inputChannels; c++)
123 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100124 for (unsigned int h = 0; h < inputHeight; h++)
125 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100126 for (unsigned int w = 0; w < inputWidth ; w++)
127 {
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100128 outData[(b+std::get<0>(m_PadList[0])) * outputChannels * outputHeight * outputWidth
129 + (c+std::get<0>(m_PadList[1])) * outputHeight * outputWidth
130 + (h+std::get<0>(m_PadList[2])) * outputWidth
131 + (w+std::get<0>(m_PadList[3]))] = inputData[b * inputChannels * inputHeight
132 * inputWidth
133 + c * inputHeight * inputWidth
134 + h * inputWidth
135 + w];
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100136 }
137 }
138 }
139 }
140
141 break;
142
143 default :
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100144
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100145 break;
146 }
Mohamed Nour Abouelseoud7420e552018-10-12 12:26:24 +0100147}
148
Mohamed Nour Abouelseouddd6acea2018-10-18 12:26:19 +0100149template void Pad<float>(const TensorInfo& inputInfo,
150 const TensorInfo& outputInfo,
151 std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
152 const float* inputData,
153 float* outData);
154template void Pad<uint8_t>(const TensorInfo& inputInfo,
155 const TensorInfo& outputInfo,
156 std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
157 const uint8_t* inputData,
158 uint8_t* outData);
159
160} //namespace armnn