blob: e3a28d13bd51c6e90d5bd04e80f1e64487e522da [file] [log] [blame]
Aron Virginas-Tard089b742019-01-29 11:09:51 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "ObjectDetectionCommon.hpp"
8
9#include <memory>
10#include <string>
11#include <vector>
12
13#include <armnn/TypesUtils.hpp>
14
15#include <boost/log/trivial.hpp>
16#include <boost/numeric/conversion/cast.hpp>
17
18#include <array>
19#include <string>
20
21#include "InferenceTestImage.hpp"
22
23namespace
24{
25
26struct MobileNetSsdTestCaseData
27{
28 MobileNetSsdTestCaseData(
29 std::vector<float> inputData,
30 std::vector<DetectedObject> expectedOutput)
31 : m_InputData(std::move(inputData))
32 , m_ExpectedOutput(std::move(expectedOutput))
33 {}
34
35 std::vector<float> m_InputData;
36 std::vector<DetectedObject> m_ExpectedOutput;
37};
38
39class MobileNetSsdDatabase
40{
41public:
42 explicit MobileNetSsdDatabase(const std::string& imageDir);
43
44 std::unique_ptr<MobileNetSsdTestCaseData> GetTestCaseData(unsigned int testCaseId);
45
46private:
47 std::string m_ImageDir;
48};
49
// Resolution (in pixels) that test images are resized to before conversion.
// The resolution is square, so the height is defined in terms of the width.
constexpr unsigned int k_MobileNetSsdImageWidth{ 300u };
constexpr unsigned int k_MobileNetSsdImageHeight{ k_MobileNetSsdImageWidth };
52
53// Test cases
54const std::array<ObjectDetectionInput, 1> g_PerTestCaseInput =
55{
56 ObjectDetectionInput
57 {
58 "Cat.jpg",
59 DetectedObject(16, BoundingBox(0.21678525f, 0.0859828f, 0.9271242f, 0.9453231f), 0.79296875f)
60 }
61};
62
63MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir)
64 : m_ImageDir(imageDir)
65{}
66
67std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(unsigned int testCaseId)
68{
69 const unsigned int safeTestCaseId =
70 testCaseId % boost::numeric_cast<unsigned int>(g_PerTestCaseInput.size());
71 const ObjectDetectionInput& testCaseInput = g_PerTestCaseInput[safeTestCaseId];
72
73 // Load test case input
74 const std::string imagePath = m_ImageDir + testCaseInput.first;
75 std::vector<float> imageData;
76 try
77 {
78 InferenceTestImage image(imagePath.c_str());
79
80 // Resize image (if needed)
81 const unsigned int width = image.GetWidth();
82 const unsigned int height = image.GetHeight();
83 if (width != k_MobileNetSsdImageWidth || height != k_MobileNetSsdImageHeight)
84 {
85 image.Resize(k_MobileNetSsdImageWidth, k_MobileNetSsdImageHeight, CHECK_LOCATION());
86 }
87
88 // Get image data as a vector of floats
89 imageData = GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout::Rgb, image);
90 }
91 catch (const InferenceTestImageException& e)
92 {
93 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
94 return nullptr;
95 }
96
97 // Prepare test case expected output
98 std::vector<DetectedObject> expectedOutput;
99 expectedOutput.reserve(1);
100 expectedOutput.push_back(testCaseInput.second);
101
102 return std::make_unique<MobileNetSsdTestCaseData>(std::move(imageData), std::move(expectedOutput));
103}
104
105} // anonymous namespace