blob: 349d1ad96acef8657291199d28f0e04d0d1ead82 [file] [log] [blame]
Aron Virginas-Tard089b742019-01-29 11:09:51 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "ObjectDetectionCommon.hpp"
8
9#include <memory>
10#include <string>
11#include <vector>
12
13#include <armnn/TypesUtils.hpp>
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000014#include <backendsCommon/test/QuantizeHelper.hpp>
Aron Virginas-Tard089b742019-01-29 11:09:51 +000015
16#include <boost/log/trivial.hpp>
17#include <boost/numeric/conversion/cast.hpp>
18
19#include <array>
20#include <string>
21
22#include "InferenceTestImage.hpp"
23
24namespace
25{
26
27struct MobileNetSsdTestCaseData
28{
29 MobileNetSsdTestCaseData(
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000030 const std::vector<uint8_t>& inputData,
31 const std::vector<DetectedObject>& expectedDetectedObject,
32 const std::vector<std::vector<float>>& expectedOutput)
33 : m_InputData(inputData)
34 , m_ExpectedDetectedObject(expectedDetectedObject)
35 , m_ExpectedOutput(expectedOutput)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000036 {}
37
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000038 std::vector<uint8_t> m_InputData;
39 std::vector<DetectedObject> m_ExpectedDetectedObject;
40 std::vector<std::vector<float>> m_ExpectedOutput;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000041};
42
43class MobileNetSsdDatabase
44{
45public:
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000046 explicit MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset);
Aron Virginas-Tard089b742019-01-29 11:09:51 +000047
48 std::unique_ptr<MobileNetSsdTestCaseData> GetTestCaseData(unsigned int testCaseId);
49
50private:
51 std::string m_ImageDir;
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000052 float m_Scale;
53 int m_Offset;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000054};
55
// Resolution that every test image is resized to before it is used as network input (square).
constexpr unsigned int k_MobileNetSsdImageWidth{300u};
constexpr unsigned int k_MobileNetSsdImageHeight{k_MobileNetSsdImageWidth};
58
59// Test cases
60const std::array<ObjectDetectionInput, 1> g_PerTestCaseInput =
61{
62 ObjectDetectionInput
63 {
64 "Cat.jpg",
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000065 {
Narumol Prangnawarat713e95c2019-06-20 17:08:03 +010066 DetectedObject(16.0f, BoundingBox(0.216785252f, 0.079726994f, 0.927124202f, 0.939067304f), 0.79296875f)
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000067 }
Aron Virginas-Tard089b742019-01-29 11:09:51 +000068 }
69};
70
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000071MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000072 : m_ImageDir(imageDir)
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000073 , m_Scale(scale)
74 , m_Offset(offset)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000075{}
76
77std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(unsigned int testCaseId)
78{
79 const unsigned int safeTestCaseId =
80 testCaseId % boost::numeric_cast<unsigned int>(g_PerTestCaseInput.size());
81 const ObjectDetectionInput& testCaseInput = g_PerTestCaseInput[safeTestCaseId];
82
83 // Load test case input
84 const std::string imagePath = m_ImageDir + testCaseInput.first;
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000085 std::vector<uint8_t> imageData;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000086 try
87 {
88 InferenceTestImage image(imagePath.c_str());
89
90 // Resize image (if needed)
91 const unsigned int width = image.GetWidth();
92 const unsigned int height = image.GetHeight();
93 if (width != k_MobileNetSsdImageWidth || height != k_MobileNetSsdImageHeight)
94 {
95 image.Resize(k_MobileNetSsdImageWidth, k_MobileNetSsdImageHeight, CHECK_LOCATION());
96 }
97
98 // Get image data as a vector of floats
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000099 std::vector<float> floatImageData = GetImageDataAsNormalizedFloats(ImageChannelLayout::Rgb, image);
100 imageData = QuantizedVector<uint8_t>(m_Scale, m_Offset, floatImageData);
Aron Virginas-Tard089b742019-01-29 11:09:51 +0000101 }
102 catch (const InferenceTestImageException& e)
103 {
104 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
105 return nullptr;
106 }
107
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000108 std::vector<float> numDetections = { static_cast<float>(testCaseInput.second.size()) };
Aron Virginas-Tard089b742019-01-29 11:09:51 +0000109
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000110 std::vector<float> detectionBoxes;
111 std::vector<float> detectionClasses;
112 std::vector<float> detectionScores;
113
114 for (DetectedObject expectedObject : testCaseInput.second)
115 {
116 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_YMin);
117 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_XMin);
118 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_YMax);
119 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_XMax);
120
121 detectionClasses.push_back(expectedObject.m_Class);
122
123 detectionScores.push_back(expectedObject.m_Confidence);
124 }
125
126 // Prepare test case expected output
127 std::vector<std::vector<float>> expectedOutputs;
128 expectedOutputs.reserve(4);
129 expectedOutputs.push_back(detectionBoxes);
130 expectedOutputs.push_back(detectionClasses);
131 expectedOutputs.push_back(detectionScores);
132 expectedOutputs.push_back(numDetections);
133
134 return std::make_unique<MobileNetSsdTestCaseData>(imageData, testCaseInput.second, expectedOutputs);
Aron Virginas-Tard089b742019-01-29 11:09:51 +0000135}
136
137} // anonymous namespace