blob: 4190e723654e861c17a37859ec979aa8c8cc72ef [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "InferenceTest.hpp"
8#include "YoloDatabase.hpp"
9
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
11
telsoa014fcda012018-03-09 14:13:49 +000012#include <algorithm>
13#include <array>
14#include <utility>
15
16#include <boost/assert.hpp>
17#include <boost/multi_array.hpp>
18#include <boost/test/tools/floating_point_comparison.hpp>
19
// Number of floats produced by the YOLO network for one image: on a 7x7 grid,
// 20 class probabilities + 2 scales + 2 boxes of 4 coordinates per cell.
constexpr size_t YoloOutputSize = 7 * 7 * 20 + 7 * 7 * 2 + 7 * 7 * 2 * 4;
21
22template <typename Model>
23class YoloTestCase : public InferenceModelTestCase<Model>
24{
25public:
26 YoloTestCase(Model& model,
27 unsigned int testCaseId,
28 YoloTestCaseData& testCaseData)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000029 : InferenceModelTestCase<Model>(model, testCaseId, { std::move(testCaseData.m_InputImage) }, { YoloOutputSize })
telsoa014fcda012018-03-09 14:13:49 +000030 , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
31 , m_TopObjectDetections(std::move(testCaseData.m_TopObjectDetections))
32 {
33 }
34
35 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) override
36 {
Jan Eilers8eb25602020-03-09 12:13:48 +000037 armnn::IgnoreUnused(options);
Derek Lambertieb1fce02019-12-10 21:20:10 +000038
telsoa014fcda012018-03-09 14:13:49 +000039 using Boost3dArray = boost::multi_array<float, 3>;
40
Ferran Balaguerc602f292019-02-08 17:09:55 +000041 const std::vector<float>& output = boost::get<std::vector<float>>(this->GetOutputs()[0]);
telsoa014fcda012018-03-09 14:13:49 +000042 BOOST_ASSERT(output.size() == YoloOutputSize);
43
44 constexpr Boost3dArray::index gridSize = 7;
45 constexpr Boost3dArray::index numClasses = 20;
46 constexpr Boost3dArray::index numScales = 2;
47
48 const float* outputPtr = output.data();
49
50 // Range 0-980. Class probabilities. 7x7x20
51 Boost3dArray classProbabilities(boost::extents[gridSize][gridSize][numClasses]);
52 for (Boost3dArray::index y = 0; y < gridSize; ++y)
53 {
54 for (Boost3dArray::index x = 0; x < gridSize; ++x)
55 {
56 for (Boost3dArray::index c = 0; c < numClasses; ++c)
57 {
58 classProbabilities[y][x][c] = *outputPtr++;
59 }
60 }
61 }
62
63 // Range 980-1078. Scales. 7x7x2
64 Boost3dArray scales(boost::extents[gridSize][gridSize][numScales]);
65 for (Boost3dArray::index y = 0; y < gridSize; ++y)
66 {
67 for (Boost3dArray::index x = 0; x < gridSize; ++x)
68 {
69 for (Boost3dArray::index s = 0; s < numScales; ++s)
70 {
71 scales[y][x][s] = *outputPtr++;
72 }
73 }
74 }
75
76 // Range 1078-1469. Bounding boxes. 7x7x2x4
77 constexpr float imageWidthAsFloat = static_cast<float>(YoloImageWidth);
78 constexpr float imageHeightAsFloat = static_cast<float>(YoloImageHeight);
79
80 boost::multi_array<float, 4> boxes(boost::extents[gridSize][gridSize][numScales][4]);
81 for (Boost3dArray::index y = 0; y < gridSize; ++y)
82 {
83 for (Boost3dArray::index x = 0; x < gridSize; ++x)
84 {
85 for (Boost3dArray::index s = 0; s < numScales; ++s)
86 {
87 float bx = *outputPtr++;
88 float by = *outputPtr++;
89 float bw = *outputPtr++;
90 float bh = *outputPtr++;
91
92 boxes[y][x][s][0] = ((bx + static_cast<float>(x)) / 7.0f) * imageWidthAsFloat;
93 boxes[y][x][s][1] = ((by + static_cast<float>(y)) / 7.0f) * imageHeightAsFloat;
94 boxes[y][x][s][2] = bw * bw * static_cast<float>(imageWidthAsFloat);
95 boxes[y][x][s][3] = bh * bh * static_cast<float>(imageHeightAsFloat);
96 }
97 }
98 }
99 BOOST_ASSERT(output.data() + YoloOutputSize == outputPtr);
100
101 std::vector<YoloDetectedObject> detectedObjects;
102 detectedObjects.reserve(gridSize * gridSize * numScales * numClasses);
103
104 for (Boost3dArray::index y = 0; y < gridSize; ++y)
105 {
106 for (Boost3dArray::index x = 0; x < gridSize; ++x)
107 {
108 for (Boost3dArray::index s = 0; s < numScales; ++s)
109 {
110 for (Boost3dArray::index c = 0; c < numClasses; ++c)
111 {
telsoa01c577f2c2018-08-31 09:22:23 +0100112 // Resolved confidence: class probabilities * scales.
telsoa014fcda012018-03-09 14:13:49 +0000113 const float confidence = classProbabilities[y][x][c] * scales[y][x][s];
114
telsoa01c577f2c2018-08-31 09:22:23 +0100115 // Resolves bounding box and stores.
telsoa014fcda012018-03-09 14:13:49 +0000116 YoloBoundingBox box;
117 box.m_X = boxes[y][x][s][0];
118 box.m_Y = boxes[y][x][s][1];
119 box.m_W = boxes[y][x][s][2];
120 box.m_H = boxes[y][x][s][3];
121
122 detectedObjects.emplace_back(c, box, confidence);
123 }
124 }
125 }
126 }
127
telsoa01c577f2c2018-08-31 09:22:23 +0100128 // Sorts detected objects by confidence.
telsoa014fcda012018-03-09 14:13:49 +0000129 std::sort(detectedObjects.begin(), detectedObjects.end(),
130 [](const YoloDetectedObject& a, const YoloDetectedObject& b)
131 {
telsoa01c577f2c2018-08-31 09:22:23 +0100132 // Sorts by largest confidence first, then by class.
telsoa014fcda012018-03-09 14:13:49 +0000133 return a.m_Confidence > b.m_Confidence
134 || (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class);
135 });
136
telsoa01c577f2c2018-08-31 09:22:23 +0100137 // Checks the top N detections.
telsoa014fcda012018-03-09 14:13:49 +0000138 auto outputIt = detectedObjects.begin();
139 auto outputEnd = detectedObjects.end();
140
141 for (const YoloDetectedObject& expectedDetection : m_TopObjectDetections)
142 {
143 if (outputIt == outputEnd)
144 {
telsoa01c577f2c2018-08-31 09:22:23 +0100145 // Somehow expected more things to check than detections found by the model.
telsoa014fcda012018-03-09 14:13:49 +0000146 return TestCaseResult::Abort;
147 }
148
149 const YoloDetectedObject& detectedObject = *outputIt;
150 if (detectedObject.m_Class != expectedDetection.m_Class)
151 {
Derek Lamberti08446972019-11-26 16:38:31 +0000152 ARMNN_LOG(error) << "Prediction for test case " << this->GetTestCaseId() <<
James Conroyca225f02018-09-18 17:06:44 +0100153 " is incorrect: Expected (" << expectedDetection.m_Class << ")" <<
154 " but predicted (" << detectedObject.m_Class << ")";
telsoa014fcda012018-03-09 14:13:49 +0000155 return TestCaseResult::Failed;
156 }
157
158 if (!m_FloatComparer(detectedObject.m_Box.m_X, expectedDetection.m_Box.m_X) ||
159 !m_FloatComparer(detectedObject.m_Box.m_Y, expectedDetection.m_Box.m_Y) ||
160 !m_FloatComparer(detectedObject.m_Box.m_W, expectedDetection.m_Box.m_W) ||
161 !m_FloatComparer(detectedObject.m_Box.m_H, expectedDetection.m_Box.m_H) ||
162 !m_FloatComparer(detectedObject.m_Confidence, expectedDetection.m_Confidence))
163 {
Derek Lamberti08446972019-11-26 16:38:31 +0000164 ARMNN_LOG(error) << "Detected bounding box for test case " << this->GetTestCaseId() <<
telsoa014fcda012018-03-09 14:13:49 +0000165 " is incorrect";
166 return TestCaseResult::Failed;
167 }
168
169 ++outputIt;
170 }
171
172 return TestCaseResult::Ok;
173 }
174
175private:
176 boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
177 std::vector<YoloDetectedObject> m_TopObjectDetections;
178};
179
180template <typename Model>
181class YoloTestCaseProvider : public IInferenceTestCaseProvider
182{
183public:
184 template <typename TConstructModelCallable>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000185 explicit YoloTestCaseProvider(TConstructModelCallable constructModel)
telsoa014fcda012018-03-09 14:13:49 +0000186 : m_ConstructModel(constructModel)
187 {
188 }
189
190 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
191 {
192 namespace po = boost::program_options;
193
194 options.add_options()
195 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
196 "Path to directory containing test data");
197
198 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
199 }
200
Matthew Bentham3e68b972019-04-09 13:10:46 +0100201 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
telsoa014fcda012018-03-09 14:13:49 +0000202 {
203 if (!ValidateDirectory(m_DataDir))
204 {
205 return false;
206 }
207
Matthew Bentham3e68b972019-04-09 13:10:46 +0100208 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000209 if (!m_Model)
210 {
211 return false;
212 }
213
214 m_Database = std::make_unique<YoloDatabase>(m_DataDir.c_str());
215 if (!m_Database)
216 {
217 return false;
218 }
219
220 return true;
221 }
222
223 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
224 {
225 std::unique_ptr<YoloTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
226 if (!testCaseData)
227 {
228 return nullptr;
229 }
230
231 return std::make_unique<YoloTestCase<Model>>(*m_Model, testCaseId, *testCaseData);
232 }
233
234private:
235 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100236 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
237 typename Model::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000238 std::unique_ptr<Model> m_Model;
239
240 std::string m_DataDir;
241 std::unique_ptr<YoloDatabase> m_Database;
242};