blob: 8e2cf47a4a1a2501a314257810c8bf4d236e2326 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Richard Burtonf32a86a2022-11-15 11:46:11 +00002 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef CLASSIFIER_HPP
18#define CLASSIFIER_HPP
19
#include "ClassificationResult.hpp"
#include "TensorFlowLiteMicro.hpp"

#include <set>
#include <string>
#include <utility>
#include <vector>
24
25namespace arm {
26namespace app {
27
28 /**
29 * @brief Classifier - a helper class to get certain number of top
30 * results from the output vector from a classification NN.
31 **/
32 class Classifier{
33 public:
34 /** @brief Constructor. */
35 Classifier() = default;
36
Richard Burtonec5e99b2022-10-05 11:00:37 +010037 virtual ~Classifier() = default;
38
alexander3c798932021-03-26 21:42:19 +000039 /**
40 * @brief Gets the top N classification results from the
41 * output vector.
42 * @param[in] outputTensor Inference output tensor from an NN model.
43 * @param[out] vecResults A vector of classification results.
44 * populated by this function.
45 * @param[in] labels Labels vector to match classified classes.
Richard Burtonec5e99b2022-10-05 11:00:37 +010046 * @param[in] topNCount Number of top classifications to pick.
47 * @param[in] useSoftmax Whether Softmax normalisation should be applied to output.
alexander3c798932021-03-26 21:42:19 +000048 * @return true if successful, false otherwise.
49 **/
Kshitij Sisodia76a15802021-12-24 11:05:11 +000050
alexander3c798932021-03-26 21:42:19 +000051 virtual bool GetClassificationResults(
52 TfLiteTensor* outputTensor,
53 std::vector<ClassificationResult>& vecResults,
Kshitij Sisodia76a15802021-12-24 11:05:11 +000054 const std::vector <std::string>& labels, uint32_t topNCount,
alexander31ae9f02022-02-10 16:15:54 +000055 bool use_softmax);
Kshitij Sisodia76a15802021-12-24 11:05:11 +000056
57 /**
58 * @brief Populate the elements of the Classification Result object.
59 * @param[in] topNSet Ordered set of top 5 output class scores and labels.
60 * @param[out] vecResults A vector of classification results.
61 * populated by this function.
62 * @param[in] labels Labels vector to match classified classes.
63 **/
64
65 void SetVectorResults(
66 std::set<std::pair<float, uint32_t>>& topNSet,
67 std::vector<ClassificationResult>& vecResults,
68 const std::vector <std::string>& labels);
alexander3c798932021-03-26 21:42:19 +000069
Richard Burtonec5e99b2022-10-05 11:00:37 +010070 protected:
alexander3c798932021-03-26 21:42:19 +000071 /**
72 * @brief Utility function that gets the top N classification results from the
73 * output vector.
alexander3c798932021-03-26 21:42:19 +000074 * @param[in] tensor Inference output tensor from an NN model.
75 * @param[out] vecResults A vector of classification results
76 * populated by this function.
77 * @param[in] topNCount Number of top classifications to pick.
78 * @param[in] labels Labels vector to match classified classes.
79 * @return true if successful, false otherwise.
80 **/
Kshitij Sisodia76a15802021-12-24 11:05:11 +000081
82 bool GetTopNResults(const std::vector<float>& tensor,
alexander3c798932021-03-26 21:42:19 +000083 std::vector<ClassificationResult>& vecResults,
84 uint32_t topNCount,
85 const std::vector <std::string>& labels);
86 };
87
88} /* namespace app */
89} /* namespace arm */
90
91#endif /* CLASSIFIER_HPP */