blob: 0a235c9a3ef98174450d0977cc2cd01b0232f779 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#include "InferenceTestImage.hpp"
6#include "ImageNetDatabase.hpp"
7
8#include <boost/numeric/conversion/cast.hpp>
9#include <boost/log/trivial.hpp>
10#include <boost/assert.hpp>
11#include <boost/format.hpp>
12
13#include <iostream>
14#include <fcntl.h>
15#include <array>
16
17const std::vector<ImageSet> g_DefaultImageSet =
18{
19 {"shark.jpg", 2}
20};
21
22ImageNetDatabase::ImageNetDatabase(const std::string& binaryFileDirectory, unsigned int width, unsigned int height,
23 const std::vector<ImageSet>& imageSet)
24: m_BinaryDirectory(binaryFileDirectory)
25, m_Height(height)
26, m_Width(width)
27, m_ImageSet(imageSet.empty() ? g_DefaultImageSet : imageSet)
28{
29}
30
31std::unique_ptr<ImageNetDatabase::TTestCaseData> ImageNetDatabase::GetTestCaseData(unsigned int testCaseId)
32{
33 testCaseId = testCaseId % boost::numeric_cast<unsigned int>(m_ImageSet.size());
34 const ImageSet& imageSet = m_ImageSet[testCaseId];
35 const std::string fullPath = m_BinaryDirectory + imageSet.first;
36 FILE* file = fopen(fullPath.c_str(), "rb");
37
38 if (file == nullptr)
39 {
40 BOOST_LOG_TRIVIAL(fatal) << "Failed to load " << fullPath;
41 return nullptr;
42 }
43
44 InferenceTestImage image(fullPath.c_str());
45 image.Resize(m_Width, m_Height);
46
47 // The model expects image data in BGR format
48 std::vector<float> inputImageData = GetImageDataInArmNnLayoutAsFloatsSubtractingMean(ImageChannelLayout::Bgr,
49 image, m_MeanBgr);
50
51 // list of labels: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
52 const unsigned int label = imageSet.second;
53 return std::make_unique<TTestCaseData>(label, std::move(inputImageData));
54}