blob: abf994b5e14377ca7fa40aaca57fa3e801f2ca21 [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <boost/variant/apply_visitor.hpp>

#include <armnn/Types.hpp>
17
18namespace armnnUtils
19{
20
21using namespace armnn;
22
23class ModelAccuracyChecker
24{
25public:
26 ModelAccuracyChecker(const std::map<std::string, int>& validationLabelSet);
27
28 float GetAccuracy(unsigned int k);
29
30 template<typename TContainer>
31 void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor)
32 {
33 // Increment the total number of images processed
34 ++m_ImagesProcessed;
35
36 std::map<int, float> confidenceMap;
37 auto & output = outputTensor[0];
38
39 // Create a map of all predictions
40 boost::apply_visitor([&](auto && value)
41 {
42 int index = 0;
43 for (const auto & o : value)
44 {
45 if (o > 0)
46 {
47 confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o)));
48 }
49 ++index;
50 }
51 },
52 output);
53
54 // Create a comparator for sorting the map in order of highest probability
55 typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator;
56
57 Comparator compFunctor =
58 [](std::pair<int, float> element1, std::pair<int, float> element2)
59 {
60 return element1.second > element2.second;
61 };
62
63 // Do the sorting and store in an ordered set
64 std::set<std::pair<int, float>, Comparator> setOfPredictions(
65 confidenceMap.begin(), confidenceMap.end(), compFunctor);
66
67 std::string trimmedName = GetTrimmedImageName(imageName);
68 int value = m_GroundTruthLabelSet.find(trimmedName)->second;
69
70 unsigned int index = 1;
71 for (std::pair<int, float> element : setOfPredictions)
72 {
73 if(element.first == value)
74 {
75 ++m_TopK[index];
76 } else
77 {
78 ++index;
79 }
80 }
81 }
82
83 std::string GetTrimmedImageName(const std::string& imageName) const
84 {
85 std::string trimmedName;
86 size_t lastindex = imageName.find_last_of(".");
87 if(lastindex != std::string::npos)
88 {
89 trimmedName = imageName.substr(0, lastindex);
90 } else
91 {
92 trimmedName = imageName;
93 }
94 return trimmedName;
95 }
96
97private:
98 const std::map<std::string, int> m_GroundTruthLabelSet;
99 std::vector<unsigned int> m_TopK = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
100 unsigned int m_ImagesProcessed = 0;
101};
102} //namespace armnnUtils
103