blob: 64409d690458d4af34e2d974af60875957bbc377 [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01008#include <algorithm>
SiCong Li898a3242019-06-24 16:03:33 +01009#include <armnn/Types.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
James Ward6d9f5c52020-09-28 11:56:35 +010011#include <mapbox/variant.hpp>
SiCong Li898a3242019-06-24 16:03:33 +010012#include <cstddef>
13#include <functional>
14#include <iostream>
15#include <map>
16#include <string>
17#include <vector>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010018
19namespace armnnUtils
20{
21
22using namespace armnn;
23
// Category names associated with a single label (one label may carry several names)
using LabelCategoryNames = std::vector<std::string>;
26
/** Split a string into tokens by a delimiter
 *
 * @param[in] originalString    Original string to be split
 * @param[in] delimiter         Delimiter used to split \p originalString (defaults to a single space)
 * @param[in] includeEmptyToken If true, include empty tokens in the result (defaults to false)
 * @return A vector of tokens split from \p originalString by \p delimiter
 */
std::vector<std::string>
    SplitBy(const std::string& originalString, const std::string& delimiter = " ", bool includeEmptyToken = false);
36
/** Remove any leading and trailing characters that appear in the character set.
 *
 * @param[in] originalString Original string to be stripped
 * @param[in] characterSet   Set of characters to be stripped from \p originalString (defaults to a single space)
 * @return A copy of \p originalString with the characters in \p characterSet stripped from both ends
 */
std::string Strip(const std::string& originalString, const std::string& characterSet = " ");
44
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010045class ModelAccuracyChecker
46{
47public:
SiCong Li898a3242019-06-24 16:03:33 +010048 /** Constructor for a model top k accuracy checker
49 *
50 * @param[in] validationLabelSet Mapping from names of images to be validated, to category names of their
51 corresponding ground-truth labels.
52 * @param[in] modelOutputLabels Mapping from output nodes to the category names of their corresponding labels
53 Note that an output node can have multiple category names.
54 */
55 ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabelSet,
56 const std::vector<LabelCategoryNames>& modelOutputLabels);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010057
SiCong Li898a3242019-06-24 16:03:33 +010058 /** Get Top K accuracy
59 *
60 * @param[in] k The number of top predictions to use for validating the ground-truth label. For example, if \p k is
61 3, then a prediction is considered correct as long as the ground-truth appears in the top 3
62 predictions.
63 * @return The accuracy, according to the top \p k th predictions.
64 */
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010065 float GetAccuracy(unsigned int k);
66
SiCong Li898a3242019-06-24 16:03:33 +010067 /** Record the prediction result of an image
68 *
69 * @param[in] imageName Name of the image.
70 * @param[in] outputTensor Output tensor of the network running \p imageName.
71 */
72 template <typename TContainer>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010073 void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor)
74 {
75 // Increment the total number of images processed
76 ++m_ImagesProcessed;
77
78 std::map<int, float> confidenceMap;
SiCong Li898a3242019-06-24 16:03:33 +010079 auto& output = outputTensor[0];
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010080
81 // Create a map of all predictions
James Ward6d9f5c52020-09-28 11:56:35 +010082 mapbox::util::apply_visitor([&confidenceMap](auto && value)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010083 {
84 int index = 0;
85 for (const auto & o : value)
86 {
87 if (o > 0)
88 {
89 confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o)));
90 }
91 ++index;
92 }
93 },
94 output);
95
96 // Create a comparator for sorting the map in order of highest probability
97 typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator;
98
99 Comparator compFunctor =
100 [](std::pair<int, float> element1, std::pair<int, float> element2)
101 {
102 return element1.second > element2.second;
103 };
104
105 // Do the sorting and store in an ordered set
106 std::set<std::pair<int, float>, Comparator> setOfPredictions(
107 confidenceMap.begin(), confidenceMap.end(), compFunctor);
108
SiCong Li898a3242019-06-24 16:03:33 +0100109 const std::string correctLabel = m_GroundTruthLabelSet.at(imageName);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100110
111 unsigned int index = 1;
112 for (std::pair<int, float> element : setOfPredictions)
113 {
SiCong Lie4403292019-06-19 18:07:25 +0100114 if (index >= m_TopK.size())
115 {
116 break;
117 }
SiCong Li898a3242019-06-24 16:03:33 +0100118 // Check if the ground truth label value is included in the topi prediction.
119 // Note that a prediction can have multiple prediction labels.
120 const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)];
121 if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end())
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100122 {
123 ++m_TopK[index];
SiCong Lie4403292019-06-19 18:07:25 +0100124 break;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100125 }
SiCong Lie4403292019-06-19 18:07:25 +0100126 ++index;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100127 }
128 }
129
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100130private:
SiCong Li898a3242019-06-24 16:03:33 +0100131 const std::map<std::string, std::string> m_GroundTruthLabelSet;
132 const std::vector<LabelCategoryNames> m_ModelOutputLabels;
133 std::vector<unsigned int> m_TopK = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
134 unsigned int m_ImagesProcessed = 0;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100135};
136} //namespace armnnUtils
137