blob: 5e2a4820facfcfe4a9b6998cff75b4958efa507b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "InferenceTest.hpp"
8#include "YoloDatabase.hpp"
9
10#include <algorithm>
11#include <array>
12#include <utility>
13
14#include <boost/assert.hpp>
15#include <boost/multi_array.hpp>
16#include <boost/test/tools/floating_point_comparison.hpp>
17
// Number of floats in the flattened model output, written as the sum of the three regions
// decoded by YoloTestCase::ProcessResult: class probabilities (7x7x20), scales (7x7x2)
// and bounding boxes (7x7x2x4).
constexpr size_t YoloOutputSize = (7 * 7 * 20) + (7 * 7 * 2) + (7 * 7 * 2 * 4);
19
20template <typename Model>
21class YoloTestCase : public InferenceModelTestCase<Model>
22{
23public:
24 YoloTestCase(Model& model,
25 unsigned int testCaseId,
26 YoloTestCaseData& testCaseData)
27 : InferenceModelTestCase<Model>(model, testCaseId, std::move(testCaseData.m_InputImage), YoloOutputSize)
28 , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
29 , m_TopObjectDetections(std::move(testCaseData.m_TopObjectDetections))
30 {
31 }
32
33 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) override
34 {
35 using Boost3dArray = boost::multi_array<float, 3>;
36
37 const std::vector<float>& output = this->GetOutput();
38 BOOST_ASSERT(output.size() == YoloOutputSize);
39
40 constexpr Boost3dArray::index gridSize = 7;
41 constexpr Boost3dArray::index numClasses = 20;
42 constexpr Boost3dArray::index numScales = 2;
43
44 const float* outputPtr = output.data();
45
46 // Range 0-980. Class probabilities. 7x7x20
47 Boost3dArray classProbabilities(boost::extents[gridSize][gridSize][numClasses]);
48 for (Boost3dArray::index y = 0; y < gridSize; ++y)
49 {
50 for (Boost3dArray::index x = 0; x < gridSize; ++x)
51 {
52 for (Boost3dArray::index c = 0; c < numClasses; ++c)
53 {
54 classProbabilities[y][x][c] = *outputPtr++;
55 }
56 }
57 }
58
59 // Range 980-1078. Scales. 7x7x2
60 Boost3dArray scales(boost::extents[gridSize][gridSize][numScales]);
61 for (Boost3dArray::index y = 0; y < gridSize; ++y)
62 {
63 for (Boost3dArray::index x = 0; x < gridSize; ++x)
64 {
65 for (Boost3dArray::index s = 0; s < numScales; ++s)
66 {
67 scales[y][x][s] = *outputPtr++;
68 }
69 }
70 }
71
72 // Range 1078-1469. Bounding boxes. 7x7x2x4
73 constexpr float imageWidthAsFloat = static_cast<float>(YoloImageWidth);
74 constexpr float imageHeightAsFloat = static_cast<float>(YoloImageHeight);
75
76 boost::multi_array<float, 4> boxes(boost::extents[gridSize][gridSize][numScales][4]);
77 for (Boost3dArray::index y = 0; y < gridSize; ++y)
78 {
79 for (Boost3dArray::index x = 0; x < gridSize; ++x)
80 {
81 for (Boost3dArray::index s = 0; s < numScales; ++s)
82 {
83 float bx = *outputPtr++;
84 float by = *outputPtr++;
85 float bw = *outputPtr++;
86 float bh = *outputPtr++;
87
88 boxes[y][x][s][0] = ((bx + static_cast<float>(x)) / 7.0f) * imageWidthAsFloat;
89 boxes[y][x][s][1] = ((by + static_cast<float>(y)) / 7.0f) * imageHeightAsFloat;
90 boxes[y][x][s][2] = bw * bw * static_cast<float>(imageWidthAsFloat);
91 boxes[y][x][s][3] = bh * bh * static_cast<float>(imageHeightAsFloat);
92 }
93 }
94 }
95 BOOST_ASSERT(output.data() + YoloOutputSize == outputPtr);
96
97 std::vector<YoloDetectedObject> detectedObjects;
98 detectedObjects.reserve(gridSize * gridSize * numScales * numClasses);
99
100 for (Boost3dArray::index y = 0; y < gridSize; ++y)
101 {
102 for (Boost3dArray::index x = 0; x < gridSize; ++x)
103 {
104 for (Boost3dArray::index s = 0; s < numScales; ++s)
105 {
106 for (Boost3dArray::index c = 0; c < numClasses; ++c)
107 {
telsoa01c577f2c2018-08-31 09:22:23 +0100108 // Resolved confidence: class probabilities * scales.
telsoa014fcda012018-03-09 14:13:49 +0000109 const float confidence = classProbabilities[y][x][c] * scales[y][x][s];
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111 // Resolves bounding box and stores.
telsoa014fcda012018-03-09 14:13:49 +0000112 YoloBoundingBox box;
113 box.m_X = boxes[y][x][s][0];
114 box.m_Y = boxes[y][x][s][1];
115 box.m_W = boxes[y][x][s][2];
116 box.m_H = boxes[y][x][s][3];
117
118 detectedObjects.emplace_back(c, box, confidence);
119 }
120 }
121 }
122 }
123
telsoa01c577f2c2018-08-31 09:22:23 +0100124 // Sorts detected objects by confidence.
telsoa014fcda012018-03-09 14:13:49 +0000125 std::sort(detectedObjects.begin(), detectedObjects.end(),
126 [](const YoloDetectedObject& a, const YoloDetectedObject& b)
127 {
telsoa01c577f2c2018-08-31 09:22:23 +0100128 // Sorts by largest confidence first, then by class.
telsoa014fcda012018-03-09 14:13:49 +0000129 return a.m_Confidence > b.m_Confidence
130 || (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class);
131 });
132
telsoa01c577f2c2018-08-31 09:22:23 +0100133 // Checks the top N detections.
telsoa014fcda012018-03-09 14:13:49 +0000134 auto outputIt = detectedObjects.begin();
135 auto outputEnd = detectedObjects.end();
136
137 for (const YoloDetectedObject& expectedDetection : m_TopObjectDetections)
138 {
139 if (outputIt == outputEnd)
140 {
telsoa01c577f2c2018-08-31 09:22:23 +0100141 // Somehow expected more things to check than detections found by the model.
telsoa014fcda012018-03-09 14:13:49 +0000142 return TestCaseResult::Abort;
143 }
144
145 const YoloDetectedObject& detectedObject = *outputIt;
146 if (detectedObject.m_Class != expectedDetection.m_Class)
147 {
148 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << this->GetTestCaseId() <<
James Conroyca225f02018-09-18 17:06:44 +0100149 " is incorrect: Expected (" << expectedDetection.m_Class << ")" <<
150 " but predicted (" << detectedObject.m_Class << ")";
telsoa014fcda012018-03-09 14:13:49 +0000151 return TestCaseResult::Failed;
152 }
153
154 if (!m_FloatComparer(detectedObject.m_Box.m_X, expectedDetection.m_Box.m_X) ||
155 !m_FloatComparer(detectedObject.m_Box.m_Y, expectedDetection.m_Box.m_Y) ||
156 !m_FloatComparer(detectedObject.m_Box.m_W, expectedDetection.m_Box.m_W) ||
157 !m_FloatComparer(detectedObject.m_Box.m_H, expectedDetection.m_Box.m_H) ||
158 !m_FloatComparer(detectedObject.m_Confidence, expectedDetection.m_Confidence))
159 {
160 BOOST_LOG_TRIVIAL(error) << "Detected bounding box for test case " << this->GetTestCaseId() <<
161 " is incorrect";
162 return TestCaseResult::Failed;
163 }
164
165 ++outputIt;
166 }
167
168 return TestCaseResult::Ok;
169 }
170
171private:
172 boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
173 std::vector<YoloDetectedObject> m_TopObjectDetections;
174};
175
176template <typename Model>
177class YoloTestCaseProvider : public IInferenceTestCaseProvider
178{
179public:
180 template <typename TConstructModelCallable>
181 YoloTestCaseProvider(TConstructModelCallable constructModel)
182 : m_ConstructModel(constructModel)
183 {
184 }
185
186 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
187 {
188 namespace po = boost::program_options;
189
190 options.add_options()
191 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
192 "Path to directory containing test data");
193
194 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
195 }
196
197 virtual bool ProcessCommandLineOptions() override
198 {
199 if (!ValidateDirectory(m_DataDir))
200 {
201 return false;
202 }
203
204 m_Model = m_ConstructModel(m_ModelCommandLineOptions);
205 if (!m_Model)
206 {
207 return false;
208 }
209
210 m_Database = std::make_unique<YoloDatabase>(m_DataDir.c_str());
211 if (!m_Database)
212 {
213 return false;
214 }
215
216 return true;
217 }
218
219 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
220 {
221 std::unique_ptr<YoloTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
222 if (!testCaseData)
223 {
224 return nullptr;
225 }
226
227 return std::make_unique<YoloTestCase<Model>>(*m_Model, testCaseId, *testCaseData);
228 }
229
230private:
231 typename Model::CommandLineOptions m_ModelCommandLineOptions;
232 std::function<std::unique_ptr<Model>(typename Model::CommandLineOptions)> m_ConstructModel;
233 std::unique_ptr<Model> m_Model;
234
235 std::string m_DataDir;
236 std::unique_ptr<YoloDatabase> m_Database;
237};