blob: d641c22ffb5a73ec3fd56fddec265873a460c7d8 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef CLASSIFIER_HPP
18#define CLASSIFIER_HPP
19
#include "ClassificationResult.hpp"
#include "TensorFlowLiteMicro.hpp"

#include <cstdint>
#include <set>
#include <string>
#include <utility>
#include <vector>
24
25namespace arm {
26namespace app {
27
28 /**
29 * @brief Classifier - a helper class to get certain number of top
30 * results from the output vector from a classification NN.
31 **/
32 class Classifier{
33 public:
34 /** @brief Constructor. */
35 Classifier() = default;
36
37 /**
38 * @brief Gets the top N classification results from the
39 * output vector.
40 * @param[in] outputTensor Inference output tensor from an NN model.
41 * @param[out] vecResults A vector of classification results.
42 * populated by this function.
43 * @param[in] labels Labels vector to match classified classes.
44 * @param[in] topNCount Number of top classifications to pick. Default is 1.
Kshitij Sisodia76a15802021-12-24 11:05:11 +000045 * @param[in] useSoftmax Whether Softmax normalisation should be applied to output. Default is false.
alexander3c798932021-03-26 21:42:19 +000046 * @return true if successful, false otherwise.
47 **/
Kshitij Sisodia76a15802021-12-24 11:05:11 +000048
alexander3c798932021-03-26 21:42:19 +000049 virtual bool GetClassificationResults(
50 TfLiteTensor* outputTensor,
51 std::vector<ClassificationResult>& vecResults,
Kshitij Sisodia76a15802021-12-24 11:05:11 +000052 const std::vector <std::string>& labels, uint32_t topNCount,
alexander31ae9f02022-02-10 16:15:54 +000053 bool use_softmax);
Kshitij Sisodia76a15802021-12-24 11:05:11 +000054
55 /**
56 * @brief Populate the elements of the Classification Result object.
57 * @param[in] topNSet Ordered set of top 5 output class scores and labels.
58 * @param[out] vecResults A vector of classification results.
59 * populated by this function.
60 * @param[in] labels Labels vector to match classified classes.
61 **/
62
63 void SetVectorResults(
64 std::set<std::pair<float, uint32_t>>& topNSet,
65 std::vector<ClassificationResult>& vecResults,
66 const std::vector <std::string>& labels);
alexander3c798932021-03-26 21:42:19 +000067
68 private:
69 /**
70 * @brief Utility function that gets the top N classification results from the
71 * output vector.
alexander3c798932021-03-26 21:42:19 +000072 * @param[in] tensor Inference output tensor from an NN model.
73 * @param[out] vecResults A vector of classification results
74 * populated by this function.
75 * @param[in] topNCount Number of top classifications to pick.
76 * @param[in] labels Labels vector to match classified classes.
77 * @return true if successful, false otherwise.
78 **/
Kshitij Sisodia76a15802021-12-24 11:05:11 +000079
80 bool GetTopNResults(const std::vector<float>& tensor,
alexander3c798932021-03-26 21:42:19 +000081 std::vector<ClassificationResult>& vecResults,
82 uint32_t topNCount,
83 const std::vector <std::string>& labels);
84 };
85
86} /* namespace app */
87} /* namespace arm */
88
89#endif /* CLASSIFIER_HPP */