blob: 1a99ed715a3978d9dbb372cfc5be8746d286f61a [file] [log] [blame]
Aron Virginas-Tard089b742019-01-29 11:09:51 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Aron Virginas-Tar48623a02019-10-22 10:00:28 +01005
Aron Virginas-Tard089b742019-01-29 11:09:51 +00006#pragma once
7
Aron Virginas-Tar48623a02019-10-22 10:00:28 +01008#include "InferenceTestImage.hpp"
Aron Virginas-Tard089b742019-01-29 11:09:51 +00009#include "ObjectDetectionCommon.hpp"
10
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010011#include <QuantizeHelper.hpp>
Aron Virginas-Tard089b742019-01-29 11:09:51 +000012
13#include <armnn/TypesUtils.hpp>
14
15#include <boost/log/trivial.hpp>
16#include <boost/numeric/conversion/cast.hpp>
17
18#include <array>
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010019#include <memory>
Aron Virginas-Tard089b742019-01-29 11:09:51 +000020#include <string>
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010021#include <vector>
Aron Virginas-Tard089b742019-01-29 11:09:51 +000022
23namespace
24{
25
26struct MobileNetSsdTestCaseData
27{
28 MobileNetSsdTestCaseData(
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000029 const std::vector<uint8_t>& inputData,
30 const std::vector<DetectedObject>& expectedDetectedObject,
31 const std::vector<std::vector<float>>& expectedOutput)
32 : m_InputData(inputData)
33 , m_ExpectedDetectedObject(expectedDetectedObject)
34 , m_ExpectedOutput(expectedOutput)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000035 {}
36
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000037 std::vector<uint8_t> m_InputData;
38 std::vector<DetectedObject> m_ExpectedDetectedObject;
39 std::vector<std::vector<float>> m_ExpectedOutput;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000040};
41
42class MobileNetSsdDatabase
43{
44public:
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000045 explicit MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset);
Aron Virginas-Tard089b742019-01-29 11:09:51 +000046
47 std::unique_ptr<MobileNetSsdTestCaseData> GetTestCaseData(unsigned int testCaseId);
48
49private:
50 std::string m_ImageDir;
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000051 float m_Scale;
52 int m_Offset;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000053};
54
// Size (in pixels) that every test image is resized to before quantization;
// the images are square, so the height simply mirrors the width.
constexpr unsigned int k_MobileNetSsdImageWidth{ 300u };
constexpr unsigned int k_MobileNetSsdImageHeight{ k_MobileNetSsdImageWidth };
57
58// Test cases
59const std::array<ObjectDetectionInput, 1> g_PerTestCaseInput =
60{
61 ObjectDetectionInput
62 {
63 "Cat.jpg",
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000064 {
Narumol Prangnawarat713e95c2019-06-20 17:08:03 +010065 DetectedObject(16.0f, BoundingBox(0.216785252f, 0.079726994f, 0.927124202f, 0.939067304f), 0.79296875f)
Narumol Prangnawarat4628d052019-02-25 17:26:05 +000066 }
Aron Virginas-Tard089b742019-01-29 11:09:51 +000067 }
68};
69
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000070MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000071 : m_ImageDir(imageDir)
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000072 , m_Scale(scale)
73 , m_Offset(offset)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000074{}
75
76std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(unsigned int testCaseId)
77{
78 const unsigned int safeTestCaseId =
79 testCaseId % boost::numeric_cast<unsigned int>(g_PerTestCaseInput.size());
80 const ObjectDetectionInput& testCaseInput = g_PerTestCaseInput[safeTestCaseId];
81
82 // Load test case input
83 const std::string imagePath = m_ImageDir + testCaseInput.first;
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000084 std::vector<uint8_t> imageData;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000085 try
86 {
87 InferenceTestImage image(imagePath.c_str());
88
89 // Resize image (if needed)
90 const unsigned int width = image.GetWidth();
91 const unsigned int height = image.GetHeight();
92 if (width != k_MobileNetSsdImageWidth || height != k_MobileNetSsdImageHeight)
93 {
94 image.Resize(k_MobileNetSsdImageWidth, k_MobileNetSsdImageHeight, CHECK_LOCATION());
95 }
96
97 // Get image data as a vector of floats
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000098 std::vector<float> floatImageData = GetImageDataAsNormalizedFloats(ImageChannelLayout::Rgb, image);
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010099 imageData = armnnUtils::QuantizedVector<uint8_t>(floatImageData, m_Scale, m_Offset);
Aron Virginas-Tard089b742019-01-29 11:09:51 +0000100 }
101 catch (const InferenceTestImageException& e)
102 {
103 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
104 return nullptr;
105 }
106
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000107 std::vector<float> numDetections = { static_cast<float>(testCaseInput.second.size()) };
Aron Virginas-Tard089b742019-01-29 11:09:51 +0000108
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000109 std::vector<float> detectionBoxes;
110 std::vector<float> detectionClasses;
111 std::vector<float> detectionScores;
112
113 for (DetectedObject expectedObject : testCaseInput.second)
114 {
115 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_YMin);
116 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_XMin);
117 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_YMax);
118 detectionBoxes.push_back(expectedObject.m_BoundingBox.m_XMax);
119
120 detectionClasses.push_back(expectedObject.m_Class);
121
122 detectionScores.push_back(expectedObject.m_Confidence);
123 }
124
125 // Prepare test case expected output
126 std::vector<std::vector<float>> expectedOutputs;
127 expectedOutputs.reserve(4);
128 expectedOutputs.push_back(detectionBoxes);
129 expectedOutputs.push_back(detectionClasses);
130 expectedOutputs.push_back(detectionScores);
131 expectedOutputs.push_back(numDetections);
132
133 return std::make_unique<MobileNetSsdTestCaseData>(imageData, testCaseInput.second, expectedOutputs);
Aron Virginas-Tard089b742019-01-29 11:09:51 +0000134}
135
136} // anonymous namespace