blob: 407774e05d27142e945f99b042d658b4577e388e [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin970f43b2019-07-01 13:51:07 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "Resize.hpp"
7
8#include "TensorBufferArrayView.hpp"
9
10#include <boost/numeric/conversion/cast.hpp>
11
12#include <cmath>
13#include <algorithm>
14
15using namespace armnnUtils;
16
17namespace armnn
18{
19
20namespace
21{
22
23inline float Lerp(float a, float b, float w)
24{
25 return w * b + (1.f - w) * a;
26}
27
Teresa Charlinda1fb9b2019-07-02 13:25:22 +010028inline double EuclideanDistance(float Xa, float Ya, const unsigned int Xb, const unsigned int Yb)
29{
30 return std::sqrt(pow(Xa - boost::numeric_cast<float>(Xb), 2) + pow(Ya - boost::numeric_cast<float>(Yb), 2));
31}
32
Teresa Charlin970f43b2019-07-01 13:51:07 +010033}// anonymous namespace
34
35void Resize(Decoder<float>& in,
36 const TensorInfo& inputInfo,
37 Encoder<float>& out,
38 const TensorInfo& outputInfo,
39 DataLayoutIndexed dataLayout,
Sang-Hoon Park820eb142020-01-08 10:25:24 +000040 armnn::ResizeMethod resizeMethod,
41 bool alignCorners)
Teresa Charlin970f43b2019-07-01 13:51:07 +010042{
43 // We follow the definition of TensorFlow and AndroidNN: the top-left corner of a texel in the output
44 // image is projected into the input image to figure out the interpolants and weights. Note that this
45 // will yield different results than if projecting the centre of output texels.
46
47 const unsigned int batchSize = inputInfo.GetShape()[0];
48 const unsigned int channelCount = inputInfo.GetShape()[dataLayout.GetChannelsIndex()];
49
50 const unsigned int inputHeight = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
51 const unsigned int inputWidth = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
52 const unsigned int outputHeight = outputInfo.GetShape()[dataLayout.GetHeightIndex()];
53 const unsigned int outputWidth = outputInfo.GetShape()[dataLayout.GetWidthIndex()];
54
Sang-Hoon Park820eb142020-01-08 10:25:24 +000055 const unsigned int sizeOffset = resizeMethod == armnn::ResizeMethod::Bilinear && alignCorners ? 1 : 0;
56
Teresa Charlin970f43b2019-07-01 13:51:07 +010057 // How much to scale pixel coordinates in the output image, to get the corresponding pixel coordinates
58 // in the input image.
Sang-Hoon Park820eb142020-01-08 10:25:24 +000059 const float scaleY = boost::numeric_cast<float>(inputHeight - sizeOffset)
60 / boost::numeric_cast<float>(outputHeight - sizeOffset);
61 const float scaleX = boost::numeric_cast<float>(inputWidth - sizeOffset)
62 / boost::numeric_cast<float>(outputWidth - sizeOffset);
Teresa Charlin970f43b2019-07-01 13:51:07 +010063
64 TensorShape inputShape = inputInfo.GetShape();
65 TensorShape outputShape = outputInfo.GetShape();
66
67 for (unsigned int n = 0; n < batchSize; ++n)
68 {
69 for (unsigned int c = 0; c < channelCount; ++c)
70 {
71 for (unsigned int y = 0; y < outputHeight; ++y)
72 {
73 // Corresponding real-valued height coordinate in input image.
74 const float iy = boost::numeric_cast<float>(y) * scaleY;
75
76 // Discrete height coordinate of top-left texel (in the 2x2 texel area used for interpolation).
77 const float fiy = floorf(iy);
78 const unsigned int y0 = boost::numeric_cast<unsigned int>(fiy);
79
80 // Interpolation weight (range [0,1]).
81 const float yw = iy - fiy;
82
83 for (unsigned int x = 0; x < outputWidth; ++x)
84 {
85 // Real-valued and discrete width coordinates in input image.
86 const float ix = boost::numeric_cast<float>(x) * scaleX;
87 const float fix = floorf(ix);
88 const unsigned int x0 = boost::numeric_cast<unsigned int>(fix);
89
90 // Interpolation weight (range [0,1]).
91 const float xw = ix - fix;
92
93 // Discrete width/height coordinates of texels below and to the right of (x0, y0).
94 const unsigned int x1 = std::min(x0 + 1, inputWidth - 1u);
95 const unsigned int y1 = std::min(y0 + 1, inputHeight - 1u);
96
97 float interpolatedValue;
98 switch (resizeMethod)
99 {
100 case armnn::ResizeMethod::Bilinear:
101 {
102 in[dataLayout.GetIndex(inputShape, n, c, y0, x0)];
103 float input1 = in.Get();
104 in[dataLayout.GetIndex(inputShape, n, c, y0, x1)];
105 float input2 = in.Get();
106 in[dataLayout.GetIndex(inputShape, n, c, y1, x0)];
107 float input3 = in.Get();
108 in[dataLayout.GetIndex(inputShape, n, c, y1, x1)];
109 float input4 = in.Get();
110
111 const float ly0 = Lerp(input1, input2, xw); // lerp along row y0.
112 const float ly1 = Lerp(input3, input4, xw); // lerp along row y1.
113 interpolatedValue = Lerp(ly0, ly1, yw);
114 break;
115 }
116 case armnn::ResizeMethod::NearestNeighbor:
Teresa Charlin970f43b2019-07-01 13:51:07 +0100117 {
Teresa Charlinda1fb9b2019-07-02 13:25:22 +0100118 // calculate euclidean distance to the 4 neighbours
119 auto distance00 = EuclideanDistance(fix, fiy, x0, y0);
120 auto distance01 = EuclideanDistance(fix, fiy, x0, y1);
121 auto distance10 = EuclideanDistance(fix, fiy, x1, y0);
122 auto distance11 = EuclideanDistance(fix, fiy, x1, y1);
Teresa Charlin970f43b2019-07-01 13:51:07 +0100123
Teresa Charlinda1fb9b2019-07-02 13:25:22 +0100124 auto minimum = std::min( { distance00, distance01, distance10, distance11 } );
125
126 unsigned int xNearest = 0;
127 unsigned int yNearest = 0;
128
129 if (minimum == distance00)
130 {
131 xNearest = x0;
132 yNearest = y0;
133 }
134 else if (minimum == distance01)
135 {
136 xNearest = x0;
137 yNearest = y1;
138 }
139 else if (minimum == distance10)
140 {
141 xNearest = x1;
142 yNearest = y0;
143 }
144 else if (minimum == distance11)
145 {
146 xNearest = x1;
147 yNearest = y1;
148 }
149 else
150 {
151 throw armnn::InvalidArgumentException("Resize Nearest Neighbor failure");
152 }
Teresa Charlin970f43b2019-07-01 13:51:07 +0100153
154 in[dataLayout.GetIndex(inputShape, n, c, yNearest, xNearest)];
155 interpolatedValue = in.Get();
156 break;
157 }
Teresa Charlinda1fb9b2019-07-02 13:25:22 +0100158 default:
159 throw armnn::InvalidArgumentException("Unknown resize method: " +
160 std::to_string(static_cast<int>(resizeMethod)));
Teresa Charlin970f43b2019-07-01 13:51:07 +0100161 }
162 out[dataLayout.GetIndex(outputShape, n, c, y, x)];
163 out.Set(interpolatedValue);
164 }
165 }
166 }
167 }
168}
169
170} //namespace armnn