blob: cb1817a0bb8bf3a9342d9a647df5ad5ef9782173 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "InferenceTest.hpp"
8#include "YoloDatabase.hpp"
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
Colm Donelan9a5ce4a2020-10-29 11:39:14 +000012#include <armnnUtils/FloatingPointComparison.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000013
James Wardc89829f2020-10-12 14:17:36 +010014#include <boost/multi_array.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015#include <algorithm>
16#include <array>
17#include <utility>
18
/// Number of floats in the output tensor consumed by YoloTestCase::ProcessResult():
/// 7x7x20 class probabilities, then 7x7x2 scales, then 7x7x2x4 bounding-box values
/// (980 + 98 + 392 = 1470).
constexpr size_t YoloOutputSize = 1470;
20
21template <typename Model>
22class YoloTestCase : public InferenceModelTestCase<Model>
23{
24public:
25 YoloTestCase(Model& model,
26 unsigned int testCaseId,
27 YoloTestCaseData& testCaseData)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000028 : InferenceModelTestCase<Model>(model, testCaseId, { std::move(testCaseData.m_InputImage) }, { YoloOutputSize })
telsoa014fcda012018-03-09 14:13:49 +000029 , m_TopObjectDetections(std::move(testCaseData.m_TopObjectDetections))
30 {
31 }
32
33 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) override
34 {
Jan Eilers8eb25602020-03-09 12:13:48 +000035 armnn::IgnoreUnused(options);
Derek Lambertieb1fce02019-12-10 21:20:10 +000036
telsoa014fcda012018-03-09 14:13:49 +000037 using Boost3dArray = boost::multi_array<float, 3>;
38
James Ward6d9f5c52020-09-28 11:56:35 +010039 const std::vector<float>& output = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010040 ARMNN_ASSERT(output.size() == YoloOutputSize);
telsoa014fcda012018-03-09 14:13:49 +000041
42 constexpr Boost3dArray::index gridSize = 7;
43 constexpr Boost3dArray::index numClasses = 20;
44 constexpr Boost3dArray::index numScales = 2;
45
46 const float* outputPtr = output.data();
47
48 // Range 0-980. Class probabilities. 7x7x20
49 Boost3dArray classProbabilities(boost::extents[gridSize][gridSize][numClasses]);
50 for (Boost3dArray::index y = 0; y < gridSize; ++y)
51 {
52 for (Boost3dArray::index x = 0; x < gridSize; ++x)
53 {
54 for (Boost3dArray::index c = 0; c < numClasses; ++c)
55 {
56 classProbabilities[y][x][c] = *outputPtr++;
57 }
58 }
59 }
60
61 // Range 980-1078. Scales. 7x7x2
62 Boost3dArray scales(boost::extents[gridSize][gridSize][numScales]);
63 for (Boost3dArray::index y = 0; y < gridSize; ++y)
64 {
65 for (Boost3dArray::index x = 0; x < gridSize; ++x)
66 {
67 for (Boost3dArray::index s = 0; s < numScales; ++s)
68 {
69 scales[y][x][s] = *outputPtr++;
70 }
71 }
72 }
73
74 // Range 1078-1469. Bounding boxes. 7x7x2x4
75 constexpr float imageWidthAsFloat = static_cast<float>(YoloImageWidth);
76 constexpr float imageHeightAsFloat = static_cast<float>(YoloImageHeight);
77
78 boost::multi_array<float, 4> boxes(boost::extents[gridSize][gridSize][numScales][4]);
79 for (Boost3dArray::index y = 0; y < gridSize; ++y)
80 {
81 for (Boost3dArray::index x = 0; x < gridSize; ++x)
82 {
83 for (Boost3dArray::index s = 0; s < numScales; ++s)
84 {
85 float bx = *outputPtr++;
86 float by = *outputPtr++;
87 float bw = *outputPtr++;
88 float bh = *outputPtr++;
89
90 boxes[y][x][s][0] = ((bx + static_cast<float>(x)) / 7.0f) * imageWidthAsFloat;
91 boxes[y][x][s][1] = ((by + static_cast<float>(y)) / 7.0f) * imageHeightAsFloat;
92 boxes[y][x][s][2] = bw * bw * static_cast<float>(imageWidthAsFloat);
93 boxes[y][x][s][3] = bh * bh * static_cast<float>(imageHeightAsFloat);
94 }
95 }
96 }
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010097 ARMNN_ASSERT(output.data() + YoloOutputSize == outputPtr);
telsoa014fcda012018-03-09 14:13:49 +000098
99 std::vector<YoloDetectedObject> detectedObjects;
100 detectedObjects.reserve(gridSize * gridSize * numScales * numClasses);
101
102 for (Boost3dArray::index y = 0; y < gridSize; ++y)
103 {
104 for (Boost3dArray::index x = 0; x < gridSize; ++x)
105 {
106 for (Boost3dArray::index s = 0; s < numScales; ++s)
107 {
108 for (Boost3dArray::index c = 0; c < numClasses; ++c)
109 {
telsoa01c577f2c2018-08-31 09:22:23 +0100110 // Resolved confidence: class probabilities * scales.
telsoa014fcda012018-03-09 14:13:49 +0000111 const float confidence = classProbabilities[y][x][c] * scales[y][x][s];
112
telsoa01c577f2c2018-08-31 09:22:23 +0100113 // Resolves bounding box and stores.
telsoa014fcda012018-03-09 14:13:49 +0000114 YoloBoundingBox box;
115 box.m_X = boxes[y][x][s][0];
116 box.m_Y = boxes[y][x][s][1];
117 box.m_W = boxes[y][x][s][2];
118 box.m_H = boxes[y][x][s][3];
119
120 detectedObjects.emplace_back(c, box, confidence);
121 }
122 }
123 }
124 }
125
telsoa01c577f2c2018-08-31 09:22:23 +0100126 // Sorts detected objects by confidence.
telsoa014fcda012018-03-09 14:13:49 +0000127 std::sort(detectedObjects.begin(), detectedObjects.end(),
128 [](const YoloDetectedObject& a, const YoloDetectedObject& b)
129 {
telsoa01c577f2c2018-08-31 09:22:23 +0100130 // Sorts by largest confidence first, then by class.
telsoa014fcda012018-03-09 14:13:49 +0000131 return a.m_Confidence > b.m_Confidence
132 || (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class);
133 });
134
telsoa01c577f2c2018-08-31 09:22:23 +0100135 // Checks the top N detections.
telsoa014fcda012018-03-09 14:13:49 +0000136 auto outputIt = detectedObjects.begin();
137 auto outputEnd = detectedObjects.end();
138
139 for (const YoloDetectedObject& expectedDetection : m_TopObjectDetections)
140 {
141 if (outputIt == outputEnd)
142 {
telsoa01c577f2c2018-08-31 09:22:23 +0100143 // Somehow expected more things to check than detections found by the model.
telsoa014fcda012018-03-09 14:13:49 +0000144 return TestCaseResult::Abort;
145 }
146
147 const YoloDetectedObject& detectedObject = *outputIt;
148 if (detectedObject.m_Class != expectedDetection.m_Class)
149 {
Derek Lamberti08446972019-11-26 16:38:31 +0000150 ARMNN_LOG(error) << "Prediction for test case " << this->GetTestCaseId() <<
James Conroyca225f02018-09-18 17:06:44 +0100151 " is incorrect: Expected (" << expectedDetection.m_Class << ")" <<
152 " but predicted (" << detectedObject.m_Class << ")";
telsoa014fcda012018-03-09 14:13:49 +0000153 return TestCaseResult::Failed;
154 }
155
Colm Donelan9a5ce4a2020-10-29 11:39:14 +0000156 if (!armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_X, expectedDetection.m_Box.m_X) ||
157 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_Y, expectedDetection.m_Box.m_Y) ||
158 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_W, expectedDetection.m_Box.m_W) ||
159 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_H, expectedDetection.m_Box.m_H) ||
160 !armnnUtils::within_percentage_tolerance(detectedObject.m_Confidence, expectedDetection.m_Confidence))
telsoa014fcda012018-03-09 14:13:49 +0000161 {
Derek Lamberti08446972019-11-26 16:38:31 +0000162 ARMNN_LOG(error) << "Detected bounding box for test case " << this->GetTestCaseId() <<
telsoa014fcda012018-03-09 14:13:49 +0000163 " is incorrect";
164 return TestCaseResult::Failed;
165 }
166
167 ++outputIt;
168 }
169
170 return TestCaseResult::Ok;
171 }
172
173private:
telsoa014fcda012018-03-09 14:13:49 +0000174 std::vector<YoloDetectedObject> m_TopObjectDetections;
175};
176
177template <typename Model>
178class YoloTestCaseProvider : public IInferenceTestCaseProvider
179{
180public:
181 template <typename TConstructModelCallable>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000182 explicit YoloTestCaseProvider(TConstructModelCallable constructModel)
telsoa014fcda012018-03-09 14:13:49 +0000183 : m_ConstructModel(constructModel)
184 {
185 }
186
James Wardc89829f2020-10-12 14:17:36 +0100187 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
telsoa014fcda012018-03-09 14:13:49 +0000188 {
James Wardc89829f2020-10-12 14:17:36 +0100189 options
190 .allow_unrecognised_options()
191 .add_options()
192 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
telsoa014fcda012018-03-09 14:13:49 +0000193
James Wardc89829f2020-10-12 14:17:36 +0100194 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
telsoa014fcda012018-03-09 14:13:49 +0000195 }
196
James Wardc89829f2020-10-12 14:17:36 +0100197 virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override
telsoa014fcda012018-03-09 14:13:49 +0000198 {
199 if (!ValidateDirectory(m_DataDir))
200 {
201 return false;
202 }
203
Matthew Bentham3e68b972019-04-09 13:10:46 +0100204 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000205 if (!m_Model)
206 {
207 return false;
208 }
209
210 m_Database = std::make_unique<YoloDatabase>(m_DataDir.c_str());
211 if (!m_Database)
212 {
213 return false;
214 }
215
216 return true;
217 }
218
219 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
220 {
221 std::unique_ptr<YoloTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
222 if (!testCaseData)
223 {
224 return nullptr;
225 }
226
227 return std::make_unique<YoloTestCase<Model>>(*m_Model, testCaseId, *testCaseData);
228 }
229
230private:
231 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100232 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
233 typename Model::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000234 std::unique_ptr<Model> m_Model;
235
236 std::string m_DataDir;
237 std::unique_ptr<YoloDatabase> m_Database;
238};