blob: cdd2af0ac5efe15f79fbf1b4acef3b4559f2d23c [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <cstddef>
9#include <string>
10#include <map>
11#include <vector>
12#include <boost/variant/apply_visitor.hpp>
13#include <iostream>
14#include <armnn/Types.hpp>
15#include <functional>
16#include <algorithm>
17
18namespace armnnUtils
19{
20
21using namespace armnn;
22
23class ModelAccuracyChecker
24{
25public:
26 ModelAccuracyChecker(const std::map<std::string, int>& validationLabelSet);
27
28 float GetAccuracy(unsigned int k);
29
30 template<typename TContainer>
31 void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor)
32 {
33 // Increment the total number of images processed
34 ++m_ImagesProcessed;
35
36 std::map<int, float> confidenceMap;
37 auto & output = outputTensor[0];
38
39 // Create a map of all predictions
40 boost::apply_visitor([&](auto && value)
41 {
42 int index = 0;
43 for (const auto & o : value)
44 {
45 if (o > 0)
46 {
47 confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o)));
48 }
49 ++index;
50 }
51 },
52 output);
53
54 // Create a comparator for sorting the map in order of highest probability
55 typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator;
56
57 Comparator compFunctor =
58 [](std::pair<int, float> element1, std::pair<int, float> element2)
59 {
60 return element1.second > element2.second;
61 };
62
63 // Do the sorting and store in an ordered set
64 std::set<std::pair<int, float>, Comparator> setOfPredictions(
65 confidenceMap.begin(), confidenceMap.end(), compFunctor);
66
67 std::string trimmedName = GetTrimmedImageName(imageName);
68 int value = m_GroundTruthLabelSet.find(trimmedName)->second;
69
70 unsigned int index = 1;
71 for (std::pair<int, float> element : setOfPredictions)
72 {
SiCong Lie4403292019-06-19 18:07:25 +010073 if (index >= m_TopK.size())
74 {
75 break;
76 }
77 if (element.first == value)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010078 {
79 ++m_TopK[index];
SiCong Lie4403292019-06-19 18:07:25 +010080 break;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010081 }
SiCong Lie4403292019-06-19 18:07:25 +010082 ++index;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010083 }
84 }
85
86 std::string GetTrimmedImageName(const std::string& imageName) const
87 {
88 std::string trimmedName;
89 size_t lastindex = imageName.find_last_of(".");
90 if(lastindex != std::string::npos)
91 {
92 trimmedName = imageName.substr(0, lastindex);
93 } else
94 {
95 trimmedName = imageName;
96 }
97 return trimmedName;
98 }
99
100private:
101 const std::map<std::string, int> m_GroundTruthLabelSet;
102 std::vector<unsigned int> m_TopK = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
103 unsigned int m_ImagesProcessed = 0;
104};
105} //namespace armnnUtils
106