blob: 81942dc2beb64cc677c8e87b19b42f43d1630a72 [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01006#include "ModelAccuracyChecker.hpp"
SiCong Li898a3242019-06-24 16:03:33 +01007#include <boost/filesystem.hpp>
8#include <boost/log/trivial.hpp>
9#include <map>
10#include <vector>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010011
12namespace armnnUtils
13{
14
SiCong Li898a3242019-06-24 16:03:33 +010015armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabels,
16 const std::vector<LabelCategoryNames>& modelOutputLabels)
17 : m_GroundTruthLabelSet(validationLabels)
18 , m_ModelOutputLabels(modelOutputLabels)
19{}
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010020
21float ModelAccuracyChecker::GetAccuracy(unsigned int k)
22{
SiCong Li898a3242019-06-24 16:03:33 +010023 if (k > 10)
24 {
25 BOOST_LOG_TRIVIAL(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
26 "Printing Top 10 Accuracy result!";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010027 k = 10;
28 }
29 unsigned int total = 0;
30 for (unsigned int i = k; i > 0; --i)
31 {
32 total += m_TopK[i];
33 }
34 return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
35}
SiCong Li898a3242019-06-24 16:03:33 +010036
// Split a string into tokens by a delimiter.
//   originalString    - the text to split.
//   delimiter         - the separator, matched as a whole substring (not as a set of characters).
//                       If it is empty, no splitting is possible and the whole string is returned
//                       as a single token.
//   includeEmptyToken - when false, zero-length tokens (from adjacent, leading or trailing
//                       delimiters, or from an empty input) are dropped; when true they are kept.
// Returns the tokens in the order they appear in originalString.
std::vector<std::string>
    SplitBy(const std::string& originalString, const std::string& delimiter, bool includeEmptyToken)
{
    std::vector<std::string> tokens;

    // An empty delimiter is found at every position without ever advancing the cursor,
    // so the search loop below would never terminate. Treat the input as one token instead.
    if (delimiter.empty())
    {
        if (!originalString.empty() || includeEmptyToken)
        {
            tokens.push_back(originalString);
        }
        return tokens;
    }

    std::size_t cur  = 0;
    std::size_t next = 0;
    while ((next = originalString.find(delimiter, cur)) != std::string::npos)
    {
        // Skip empty tokens, unless explicitly stated to include them.
        if (next - cur > 0 || includeEmptyToken)
        {
            tokens.push_back(originalString.substr(cur, next - cur));
        }
        // Resume searching just past the delimiter that was matched.
        cur = next + delimiter.size();
    }

    // Get the remaining token after the last delimiter.
    // Skip empty tokens, unless explicitly stated to include them.
    if (originalString.size() - cur > 0 || includeEmptyToken)
    {
        tokens.push_back(originalString.substr(cur, originalString.size() - cur));
    }
    return tokens;
}
61
62// Remove any preceding and trailing character specified in the characterSet.
63std::string Strip(const std::string& originalString, const std::string& characterSet)
64{
65 BOOST_ASSERT(!characterSet.empty());
66 const std::size_t firstFound = originalString.find_first_not_of(characterSet);
67 const std::size_t lastFound = originalString.find_last_not_of(characterSet);
68 // Return empty if the originalString is empty or the originalString contains only to-be-striped characters
69 if (firstFound == std::string::npos || lastFound == std::string::npos)
70 {
71 return "";
72 }
73 return originalString.substr(firstFound, lastFound + 1 - firstFound);
74}
75} // namespace armnnUtils