blob: 205460a2f27ddc50956dafb8fc8bbe0e950db6f4 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#include "InferenceTestImage.hpp"
6
7#include <boost/core/ignore_unused.hpp>
8#include <boost/format.hpp>
9#include <boost/core/ignore_unused.hpp>
10#include <boost/numeric/conversion/cast.hpp>
11
12#include <array>
13
14#define STB_IMAGE_IMPLEMENTATION
15#include <stb_image.h>
16
17#define STB_IMAGE_RESIZE_IMPLEMENTATION
18#include <stb_image_resize.h>
19
20#define STB_IMAGE_WRITE_IMPLEMENTATION
21#include <stb_image_write.h>
22
23namespace
24{
25
26unsigned int GetImageChannelIndex(ImageChannelLayout channelLayout, ImageChannel channel)
27{
28 switch (channelLayout)
29 {
30 case ImageChannelLayout::Rgb:
31 return static_cast<unsigned int>(channel);
32 case ImageChannelLayout::Bgr:
33 return 2u - static_cast<unsigned int>(channel);
34 default:
35 throw UnknownImageChannelLayout(boost::str(boost::format("Unknown layout %1%")
36 % static_cast<int>(channelLayout)));
37 }
38}
39
40} // namespace
41
42InferenceTestImage::InferenceTestImage(char const* filePath)
43 : m_Width(0u)
44 , m_Height(0u)
45 , m_NumChannels(0u)
46{
47 int width;
48 int height;
49 int channels;
50
51 using StbImageDataPtr = std::unique_ptr<unsigned char, decltype(&stbi_image_free)>;
52 StbImageDataPtr stbData(stbi_load(filePath, &width, &height, &channels, 0), &stbi_image_free);
53
54 if (stbData == nullptr)
55 {
56 throw InferenceTestImageLoadFailed(boost::str(boost::format("Could not load the image at %1%") % filePath));
57 }
58
59 if (width == 0 || height == 0)
60 {
61 throw InferenceTestImageLoadFailed(boost::str(boost::format("Could not load empty image at %1%") % filePath));
62 }
63
64 m_Width = boost::numeric_cast<unsigned int>(width);
65 m_Height = boost::numeric_cast<unsigned int>(height);
66 m_NumChannels = boost::numeric_cast<unsigned int>(channels);
67
68 const unsigned int sizeInBytes = GetSizeInBytes();
69 m_Data.resize(sizeInBytes);
70 memcpy(m_Data.data(), stbData.get(), sizeInBytes);
71}
72
73std::tuple<uint8_t, uint8_t, uint8_t> InferenceTestImage::GetPixelAs3Channels(unsigned int x, unsigned int y) const
74{
75 if (x >= m_Width || y >= m_Height)
76 {
77 throw InferenceTestImageOutOfBoundsAccess(boost::str(boost::format("Attempted out of bounds image access. "
78 "Requested (%1%, %2%). Maximum valid coordinates (%3%, %4%).") % x % y % (m_Width - 1) % (m_Height - 1)));
79 }
80
81 const unsigned int pixelOffset = x * GetNumChannels() + y * GetWidth() * GetNumChannels();
82 const uint8_t* const pixelData = m_Data.data() + pixelOffset;
83 BOOST_ASSERT(pixelData <= (m_Data.data() + GetSizeInBytes()));
84
85 std::array<uint8_t, 3> outPixelData;
86 outPixelData.fill(0);
87
88 const unsigned int maxChannelsInPixel = std::min(GetNumChannels(), static_cast<unsigned int>(outPixelData.size()));
89 for (unsigned int c = 0; c < maxChannelsInPixel; ++c)
90 {
91 outPixelData[c] = pixelData[c];
92 }
93
94 return std::make_tuple(outPixelData[0], outPixelData[1], outPixelData[2]);
95}
96
97void InferenceTestImage::Resize(unsigned int newWidth, unsigned int newHeight)
98{
99 if (newWidth == 0 || newHeight == 0)
100 {
101 throw InferenceTestImageResizeFailed(boost::str(boost::format("None of the dimensions passed to a resize "
102 "operation can be zero. Requested width: %1%. Requested height: %2%.") % newWidth % newHeight));
103 }
104
105 if (newWidth == m_Width && newHeight == m_Height)
106 {
107 // nothing to do
108 return;
109 }
110
111 std::vector<uint8_t> newData;
112 newData.resize(newWidth * newHeight * GetNumChannels() * GetSingleElementSizeInBytes());
113
114 // boost::numeric_cast<>() is used for user-provided data (protecting about overflows).
115 // static_cast<> ok for internal data (assumes that, when internal data was originally provided by a user,
116 // a boost::numeric_cast<>() handled the conversion).
117 const int nW = boost::numeric_cast<int>(newWidth);
118 const int nH = boost::numeric_cast<int>(newHeight);
119
120 const int w = static_cast<int>(GetWidth());
121 const int h = static_cast<int>(GetHeight());
122 const int numChannels = static_cast<int>(GetNumChannels());
123
124 const int res = stbir_resize_uint8(m_Data.data(), w, h, 0, newData.data(), nW, nH, 0, numChannels);
125 if (res == 0)
126 {
127 throw InferenceTestImageResizeFailed("The resizing operation failed");
128 }
129
130 m_Data.swap(newData);
131 m_Width = newWidth;
132 m_Height = newHeight;
133}
134
135void InferenceTestImage::Write(WriteFormat format, const char* filePath) const
136{
137 const int w = static_cast<int>(GetWidth());
138 const int h = static_cast<int>(GetHeight());
139 const int numChannels = static_cast<int>(GetNumChannels());
140 int res = 0;
141
142 switch (format)
143 {
144 case WriteFormat::Png:
145 {
146 res = stbi_write_png(filePath, w, h, numChannels, m_Data.data(), 0);
147 break;
148 }
149 case WriteFormat::Bmp:
150 {
151 res = stbi_write_bmp(filePath, w, h, numChannels, m_Data.data());
152 break;
153 }
154 case WriteFormat::Tga:
155 {
156 res = stbi_write_tga(filePath, w, h, numChannels, m_Data.data());
157 break;
158 }
159 default:
160 throw InferenceTestImageWriteFailed(boost::str(boost::format("Unknown format %1%")
161 % static_cast<int>(format)));
162 }
163
164 if (res == 0)
165 {
166 throw InferenceTestImageWriteFailed(boost::str(boost::format("An error occurred when writing to file %1%")
167 % filePath));
168 }
169}
170
171template <typename TProcessValueCallable>
172std::vector<float> GetImageDataInArmNnLayoutAsFloats(ImageChannelLayout channelLayout,
173 const InferenceTestImage& image,
174 TProcessValueCallable processValue)
175{
176 const unsigned int h = image.GetHeight();
177 const unsigned int w = image.GetWidth();
178
179 std::vector<float> imageData;
180 imageData.resize(h * w * 3);
181
182 for (unsigned int j = 0; j < h; ++j)
183 {
184 for (unsigned int i = 0; i < w; ++i)
185 {
186 uint8_t r, g, b;
187 std::tie(r, g, b) = image.GetPixelAs3Channels(i, j);
188
189 // ArmNN order: C, H, W
190 const unsigned int rDstIndex = GetImageChannelIndex(channelLayout, ImageChannel::R) * h * w + j * w + i;
191 const unsigned int gDstIndex = GetImageChannelIndex(channelLayout, ImageChannel::G) * h * w + j * w + i;
192 const unsigned int bDstIndex = GetImageChannelIndex(channelLayout, ImageChannel::B) * h * w + j * w + i;
193
194 imageData[rDstIndex] = processValue(ImageChannel::R, float(r));
195 imageData[gDstIndex] = processValue(ImageChannel::G, float(g));
196 imageData[bDstIndex] = processValue(ImageChannel::B, float(b));
197 }
198 }
199
200 return imageData;
201}
202
203std::vector<float> GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout layout,
204 const InferenceTestImage& image)
205{
206 return GetImageDataInArmNnLayoutAsFloats(layout, image,
207 [](ImageChannel channel, float value)
208 {
209 boost::ignore_unused(channel);
210 return value / 255.f;
211 });
212}
213
214std::vector<float> GetImageDataInArmNnLayoutAsFloatsSubtractingMean(ImageChannelLayout layout,
215 const InferenceTestImage& image,
216 const std::array<float, 3>& mean)
217{
218 return GetImageDataInArmNnLayoutAsFloats(layout, image,
219 [layout, &mean](ImageChannel channel, float value)
220 {
221 const unsigned int channelIndex = GetImageChannelIndex(layout, channel);
222 return value - mean[channelIndex];
223 });
224}
surmeh01bceff2f2018-03-29 16:29:27 +0100225
226std::vector<float> GetImageDataAsNormalizedFloats(ImageChannelLayout layout,
227 const InferenceTestImage& image)
228{
229 std::vector<float> imageData;
230 const unsigned int h = image.GetHeight();
231 const unsigned int w = image.GetWidth();
232
233 const unsigned int rDstIndex = GetImageChannelIndex(layout, ImageChannel::R);
234 const unsigned int gDstIndex = GetImageChannelIndex(layout, ImageChannel::G);
235 const unsigned int bDstIndex = GetImageChannelIndex(layout, ImageChannel::B);
236
237 imageData.resize(h * w * 3);
238 unsigned int offset = 0;
239
240 for (unsigned int j = 0; j < h; ++j)
241 {
242 for (unsigned int i = 0; i < w; ++i)
243 {
244 uint8_t r, g, b;
245 std::tie(r, g, b) = image.GetPixelAs3Channels(i, j);
246
247 imageData[offset+rDstIndex] = float(r) / 255.0f;
248 imageData[offset+gDstIndex] = float(g) / 255.0f;
249 imageData[offset+bDstIndex] = float(b) / 255.0f;
250 offset += 3;
251 }
252 }
253
254 return imageData;
255}