blob: cac558793ff51950dec5603c785c92699f2a42ec [file] [log] [blame]
Aron Virginas-Tard089b742019-01-29 11:09:51 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "ObjectDetectionCommon.hpp"
8
9#include <memory>
10#include <string>
11#include <vector>
12
13#include <armnn/TypesUtils.hpp>
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000014#include <backendsCommon/test/QuantizeHelper.hpp>
Aron Virginas-Tard089b742019-01-29 11:09:51 +000015
16#include <boost/log/trivial.hpp>
17#include <boost/numeric/conversion/cast.hpp>
18
19#include <array>
20#include <string>
21
22#include "InferenceTestImage.hpp"
23
24namespace
25{
26
27struct MobileNetSsdTestCaseData
28{
29 MobileNetSsdTestCaseData(
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000030 std::vector<uint8_t> inputData,
Aron Virginas-Tard089b742019-01-29 11:09:51 +000031 std::vector<DetectedObject> expectedOutput)
32 : m_InputData(std::move(inputData))
33 , m_ExpectedOutput(std::move(expectedOutput))
34 {}
35
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000036 std::vector<uint8_t> m_InputData;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000037 std::vector<DetectedObject> m_ExpectedOutput;
38};
39
40class MobileNetSsdDatabase
41{
42public:
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000043 explicit MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset);
Aron Virginas-Tard089b742019-01-29 11:09:51 +000044
45 std::unique_ptr<MobileNetSsdTestCaseData> GetTestCaseData(unsigned int testCaseId);
46
47private:
48 std::string m_ImageDir;
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000049 float m_Scale;
50 int m_Offset;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000051};
52
// Test images are resized to this square resolution before use.
constexpr unsigned int k_MobileNetSsdImageWidth{ 300u };
constexpr unsigned int k_MobileNetSsdImageHeight{ k_MobileNetSsdImageWidth }; // square input
55
56// Test cases
57const std::array<ObjectDetectionInput, 1> g_PerTestCaseInput =
58{
59 ObjectDetectionInput
60 {
61 "Cat.jpg",
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000062 DetectedObject(16, BoundingBox(0.208961248f, 0.0852333307f, 0.92757535f, 0.940263629f), 0.79296875f)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000063 }
64};
65
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000066MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000067 : m_ImageDir(imageDir)
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000068 , m_Scale(scale)
69 , m_Offset(offset)
Aron Virginas-Tard089b742019-01-29 11:09:51 +000070{}
71
72std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(unsigned int testCaseId)
73{
74 const unsigned int safeTestCaseId =
75 testCaseId % boost::numeric_cast<unsigned int>(g_PerTestCaseInput.size());
76 const ObjectDetectionInput& testCaseInput = g_PerTestCaseInput[safeTestCaseId];
77
78 // Load test case input
79 const std::string imagePath = m_ImageDir + testCaseInput.first;
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000080 std::vector<uint8_t> imageData;
Aron Virginas-Tard089b742019-01-29 11:09:51 +000081 try
82 {
83 InferenceTestImage image(imagePath.c_str());
84
85 // Resize image (if needed)
86 const unsigned int width = image.GetWidth();
87 const unsigned int height = image.GetHeight();
88 if (width != k_MobileNetSsdImageWidth || height != k_MobileNetSsdImageHeight)
89 {
90 image.Resize(k_MobileNetSsdImageWidth, k_MobileNetSsdImageHeight, CHECK_LOCATION());
91 }
92
93 // Get image data as a vector of floats
Narumol Prangnawaratc8bab1b2019-02-15 17:34:51 +000094 std::vector<float> floatImageData = GetImageDataAsNormalizedFloats(ImageChannelLayout::Rgb, image);
95 imageData = QuantizedVector<uint8_t>(m_Scale, m_Offset, floatImageData);
Aron Virginas-Tard089b742019-01-29 11:09:51 +000096 }
97 catch (const InferenceTestImageException& e)
98 {
99 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
100 return nullptr;
101 }
102
103 // Prepare test case expected output
104 std::vector<DetectedObject> expectedOutput;
105 expectedOutput.reserve(1);
106 expectedOutput.push_back(testCaseInput.second);
107
108 return std::make_unique<MobileNetSsdTestCaseData>(std::move(imageData), std::move(expectedOutput));
109}
110
111} // anonymous namespace