//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "InferenceTest.hpp"
#include "YoloDatabase.hpp"

#include <armnn/utility/Assert.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnnUtils/FloatingPointComparison.hpp>

#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

/// Number of floats in the YOLO output tensor consumed by YoloTestCase:
/// 7x7x20 class probabilities + 7x7x2 scales + 7x7x2x4 box coordinates = 1470.
constexpr size_t YoloOutputSize = 1470;

20template <typename Model>
21class YoloTestCase : public InferenceModelTestCase<Model>
22{
23public:
24 YoloTestCase(Model& model,
25 unsigned int testCaseId,
26 YoloTestCaseData& testCaseData)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000027 : InferenceModelTestCase<Model>(model, testCaseId, { std::move(testCaseData.m_InputImage) }, { YoloOutputSize })
telsoa014fcda012018-03-09 14:13:49 +000028 , m_TopObjectDetections(std::move(testCaseData.m_TopObjectDetections))
29 {
30 }
31
32 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) override
33 {
Jan Eilers8eb25602020-03-09 12:13:48 +000034 armnn::IgnoreUnused(options);
Derek Lambertieb1fce02019-12-10 21:20:10 +000035
James Ward6d9f5c52020-09-28 11:56:35 +010036 const std::vector<float>& output = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010037 ARMNN_ASSERT(output.size() == YoloOutputSize);
telsoa014fcda012018-03-09 14:13:49 +000038
Colm Donelan25d80ee2020-10-30 14:46:21 +000039 constexpr unsigned int gridSize = 7;
40 constexpr unsigned int numClasses = 20;
41 constexpr unsigned int numScales = 2;
telsoa014fcda012018-03-09 14:13:49 +000042
43 const float* outputPtr = output.data();
44
45 // Range 0-980. Class probabilities. 7x7x20
Colm Donelan25d80ee2020-10-30 14:46:21 +000046 vector<vector<vector<float>>> classProbabilities(gridSize, vector<vector<float>>(gridSize,
47 vector<float>(numClasses)));
48 for (unsigned int y = 0; y < gridSize; ++y)
telsoa014fcda012018-03-09 14:13:49 +000049 {
Colm Donelan25d80ee2020-10-30 14:46:21 +000050 for (unsigned int x = 0; x < gridSize; ++x)
telsoa014fcda012018-03-09 14:13:49 +000051 {
Colm Donelan25d80ee2020-10-30 14:46:21 +000052 for (unsigned int c = 0; c < numClasses; ++c)
telsoa014fcda012018-03-09 14:13:49 +000053 {
54 classProbabilities[y][x][c] = *outputPtr++;
55 }
56 }
57 }
58
59 // Range 980-1078. Scales. 7x7x2
Colm Donelan25d80ee2020-10-30 14:46:21 +000060 vector<vector<vector<float>>> scales(gridSize, vector<vector<float>>(gridSize, vector<float>(numScales)));
61 for (unsigned int y = 0; y < gridSize; ++y)
telsoa014fcda012018-03-09 14:13:49 +000062 {
Colm Donelan25d80ee2020-10-30 14:46:21 +000063 for (unsigned int x = 0; x < gridSize; ++x)
telsoa014fcda012018-03-09 14:13:49 +000064 {
Colm Donelan25d80ee2020-10-30 14:46:21 +000065 for (unsigned int s = 0; s < numScales; ++s)
telsoa014fcda012018-03-09 14:13:49 +000066 {
67 scales[y][x][s] = *outputPtr++;
68 }
69 }
70 }
71
72 // Range 1078-1469. Bounding boxes. 7x7x2x4
73 constexpr float imageWidthAsFloat = static_cast<float>(YoloImageWidth);
74 constexpr float imageHeightAsFloat = static_cast<float>(YoloImageHeight);
75
Colm Donelan25d80ee2020-10-30 14:46:21 +000076 vector<vector<vector<vector<float>>>> boxes(gridSize, vector<vector<vector<float>>>
77 (gridSize, vector<vector<float>>(numScales, vector<float>(4))));
78 for (unsigned int y = 0; y < gridSize; ++y)
telsoa014fcda012018-03-09 14:13:49 +000079 {
Colm Donelan25d80ee2020-10-30 14:46:21 +000080 for (unsigned int x = 0; x < gridSize; ++x)
telsoa014fcda012018-03-09 14:13:49 +000081 {
Colm Donelan25d80ee2020-10-30 14:46:21 +000082 for (unsigned int s = 0; s < numScales; ++s)
telsoa014fcda012018-03-09 14:13:49 +000083 {
84 float bx = *outputPtr++;
85 float by = *outputPtr++;
86 float bw = *outputPtr++;
87 float bh = *outputPtr++;
88
89 boxes[y][x][s][0] = ((bx + static_cast<float>(x)) / 7.0f) * imageWidthAsFloat;
90 boxes[y][x][s][1] = ((by + static_cast<float>(y)) / 7.0f) * imageHeightAsFloat;
91 boxes[y][x][s][2] = bw * bw * static_cast<float>(imageWidthAsFloat);
92 boxes[y][x][s][3] = bh * bh * static_cast<float>(imageHeightAsFloat);
93 }
94 }
95 }
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010096 ARMNN_ASSERT(output.data() + YoloOutputSize == outputPtr);
telsoa014fcda012018-03-09 14:13:49 +000097
98 std::vector<YoloDetectedObject> detectedObjects;
99 detectedObjects.reserve(gridSize * gridSize * numScales * numClasses);
100
Colm Donelan25d80ee2020-10-30 14:46:21 +0000101 for (unsigned int y = 0; y < gridSize; ++y)
telsoa014fcda012018-03-09 14:13:49 +0000102 {
Colm Donelan25d80ee2020-10-30 14:46:21 +0000103 for (unsigned int x = 0; x < gridSize; ++x)
telsoa014fcda012018-03-09 14:13:49 +0000104 {
Colm Donelan25d80ee2020-10-30 14:46:21 +0000105 for (unsigned int s = 0; s < numScales; ++s)
telsoa014fcda012018-03-09 14:13:49 +0000106 {
Colm Donelan25d80ee2020-10-30 14:46:21 +0000107 for (unsigned int c = 0; c < numClasses; ++c)
telsoa014fcda012018-03-09 14:13:49 +0000108 {
telsoa01c577f2c2018-08-31 09:22:23 +0100109 // Resolved confidence: class probabilities * scales.
telsoa014fcda012018-03-09 14:13:49 +0000110 const float confidence = classProbabilities[y][x][c] * scales[y][x][s];
111
telsoa01c577f2c2018-08-31 09:22:23 +0100112 // Resolves bounding box and stores.
telsoa014fcda012018-03-09 14:13:49 +0000113 YoloBoundingBox box;
114 box.m_X = boxes[y][x][s][0];
115 box.m_Y = boxes[y][x][s][1];
116 box.m_W = boxes[y][x][s][2];
117 box.m_H = boxes[y][x][s][3];
118
119 detectedObjects.emplace_back(c, box, confidence);
120 }
121 }
122 }
123 }
124
telsoa01c577f2c2018-08-31 09:22:23 +0100125 // Sorts detected objects by confidence.
telsoa014fcda012018-03-09 14:13:49 +0000126 std::sort(detectedObjects.begin(), detectedObjects.end(),
127 [](const YoloDetectedObject& a, const YoloDetectedObject& b)
128 {
telsoa01c577f2c2018-08-31 09:22:23 +0100129 // Sorts by largest confidence first, then by class.
telsoa014fcda012018-03-09 14:13:49 +0000130 return a.m_Confidence > b.m_Confidence
131 || (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class);
132 });
133
telsoa01c577f2c2018-08-31 09:22:23 +0100134 // Checks the top N detections.
telsoa014fcda012018-03-09 14:13:49 +0000135 auto outputIt = detectedObjects.begin();
136 auto outputEnd = detectedObjects.end();
137
138 for (const YoloDetectedObject& expectedDetection : m_TopObjectDetections)
139 {
140 if (outputIt == outputEnd)
141 {
telsoa01c577f2c2018-08-31 09:22:23 +0100142 // Somehow expected more things to check than detections found by the model.
telsoa014fcda012018-03-09 14:13:49 +0000143 return TestCaseResult::Abort;
144 }
145
146 const YoloDetectedObject& detectedObject = *outputIt;
147 if (detectedObject.m_Class != expectedDetection.m_Class)
148 {
Derek Lamberti08446972019-11-26 16:38:31 +0000149 ARMNN_LOG(error) << "Prediction for test case " << this->GetTestCaseId() <<
James Conroyca225f02018-09-18 17:06:44 +0100150 " is incorrect: Expected (" << expectedDetection.m_Class << ")" <<
151 " but predicted (" << detectedObject.m_Class << ")";
telsoa014fcda012018-03-09 14:13:49 +0000152 return TestCaseResult::Failed;
153 }
154
Colm Donelan9a5ce4a2020-10-29 11:39:14 +0000155 if (!armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_X, expectedDetection.m_Box.m_X) ||
156 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_Y, expectedDetection.m_Box.m_Y) ||
157 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_W, expectedDetection.m_Box.m_W) ||
158 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_H, expectedDetection.m_Box.m_H) ||
159 !armnnUtils::within_percentage_tolerance(detectedObject.m_Confidence, expectedDetection.m_Confidence))
telsoa014fcda012018-03-09 14:13:49 +0000160 {
Derek Lamberti08446972019-11-26 16:38:31 +0000161 ARMNN_LOG(error) << "Detected bounding box for test case " << this->GetTestCaseId() <<
telsoa014fcda012018-03-09 14:13:49 +0000162 " is incorrect";
163 return TestCaseResult::Failed;
164 }
165
166 ++outputIt;
167 }
168
169 return TestCaseResult::Ok;
170 }
171
172private:
telsoa014fcda012018-03-09 14:13:49 +0000173 std::vector<YoloDetectedObject> m_TopObjectDetections;
174};
175
176template <typename Model>
177class YoloTestCaseProvider : public IInferenceTestCaseProvider
178{
179public:
180 template <typename TConstructModelCallable>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000181 explicit YoloTestCaseProvider(TConstructModelCallable constructModel)
telsoa014fcda012018-03-09 14:13:49 +0000182 : m_ConstructModel(constructModel)
183 {
184 }
185
James Wardc89829f2020-10-12 14:17:36 +0100186 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
telsoa014fcda012018-03-09 14:13:49 +0000187 {
James Wardc89829f2020-10-12 14:17:36 +0100188 options
189 .allow_unrecognised_options()
190 .add_options()
191 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
telsoa014fcda012018-03-09 14:13:49 +0000192
James Wardc89829f2020-10-12 14:17:36 +0100193 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
telsoa014fcda012018-03-09 14:13:49 +0000194 }
195
James Wardc89829f2020-10-12 14:17:36 +0100196 virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override
telsoa014fcda012018-03-09 14:13:49 +0000197 {
198 if (!ValidateDirectory(m_DataDir))
199 {
200 return false;
201 }
202
Matthew Bentham3e68b972019-04-09 13:10:46 +0100203 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000204 if (!m_Model)
205 {
206 return false;
207 }
208
209 m_Database = std::make_unique<YoloDatabase>(m_DataDir.c_str());
210 if (!m_Database)
211 {
212 return false;
213 }
214
215 return true;
216 }
217
218 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
219 {
220 std::unique_ptr<YoloTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
221 if (!testCaseData)
222 {
223 return nullptr;
224 }
225
226 return std::make_unique<YoloTestCase<Model>>(*m_Model, testCaseId, *testCaseData);
227 }
228
229private:
230 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100231 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
232 typename Model::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000233 std::unique_ptr<Model> m_Model;
234
235 std::string m_DataDir;
236 std::unique_ptr<YoloDatabase> m_Database;
237};