blob: 510e6f9c49013d9f4145d11d131357c921d185e1 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef CLASSIFIER_HPP
18#define CLASSIFIER_HPP
19
20#include "ClassificationResult.hpp"
21#include "TensorFlowLiteMicro.hpp"
22
23#include <vector>
24
25namespace arm {
26namespace app {
27
28 /**
29 * @brief Classifier - a helper class to get certain number of top
30 * results from the output vector from a classification NN.
31 **/
32 class Classifier{
33 public:
34 /** @brief Constructor. */
35 Classifier() = default;
36
37 /**
38 * @brief Gets the top N classification results from the
39 * output vector.
40 * @param[in] outputTensor Inference output tensor from an NN model.
41 * @param[out] vecResults A vector of classification results.
42 * populated by this function.
43 * @param[in] labels Labels vector to match classified classes.
44 * @param[in] topNCount Number of top classifications to pick. Default is 1.
45 * @return true if successful, false otherwise.
46 **/
47 virtual bool GetClassificationResults(
48 TfLiteTensor* outputTensor,
49 std::vector<ClassificationResult>& vecResults,
50 const std::vector <std::string>& labels, uint32_t topNCount);
51
52 private:
53 /**
54 * @brief Utility function that gets the top N classification results from the
55 * output vector.
56 * @tparam T value type
57 * @param[in] tensor Inference output tensor from an NN model.
58 * @param[out] vecResults A vector of classification results
59 * populated by this function.
60 * @param[in] topNCount Number of top classifications to pick.
61 * @param[in] labels Labels vector to match classified classes.
62 * @return true if successful, false otherwise.
63 **/
64 template<typename T>
65 bool _GetTopNResults(TfLiteTensor* tensor,
66 std::vector<ClassificationResult>& vecResults,
67 uint32_t topNCount,
68 const std::vector <std::string>& labels);
69 };
70
71} /* namespace app */
72} /* namespace arm */
73
74#endif /* CLASSIFIER_HPP */