blob: c5519fbb05cd91dd0bb7731b82b514a3fb18a3ed [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Classifier.hpp"
18
19#include "hal.h"
20#include "TensorFlowLiteMicro.hpp"
21
22#include <vector>
23#include <string>
24#include <set>
25#include <cstdint>
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010026#include <inttypes.h>
alexander3c798932021-03-26 21:42:19 +000027
28namespace arm {
29namespace app {
30
31 template<typename T>
alexanderc350cdc2021-04-29 20:36:09 +010032 void SetVectorResults(std::set<std::pair<T, uint32_t>>& topNSet,
33 std::vector<ClassificationResult>& vecResults,
34 TfLiteTensor* tensor,
35 const std::vector <std::string>& labels) {
36
37 /* For getting the floating point values, we need quantization parameters. */
38 QuantParams quantParams = GetTensorQuantParams(tensor);
39
40 /* Reset the iterator to the largest element - use reverse iterator. */
41 auto topNIter = topNSet.rbegin();
42 for (size_t i = 0; i < vecResults.size() && topNIter != topNSet.rend(); ++i, ++topNIter) {
43 T score = topNIter->first;
44 vecResults[i].m_normalisedVal = quantParams.scale * (score - quantParams.offset);
45 vecResults[i].m_label = labels[topNIter->second];
46 vecResults[i].m_labelIdx = topNIter->second;
47 }
48
49 }
50
51 template<>
52 void SetVectorResults<float>(std::set<std::pair<float, uint32_t>>& topNSet,
53 std::vector<ClassificationResult>& vecResults,
54 TfLiteTensor* tensor,
55 const std::vector <std::string>& labels) {
56 UNUSED(tensor);
57 /* Reset the iterator to the largest element - use reverse iterator. */
58 auto topNIter = topNSet.rbegin();
59 for (size_t i = 0; i < vecResults.size() && topNIter != topNSet.rend(); ++i, ++topNIter) {
60 vecResults[i].m_normalisedVal = topNIter->first;
61 vecResults[i].m_label = labels[topNIter->second];
62 vecResults[i].m_labelIdx = topNIter->second;
63 }
64
65 }
66
67 template<typename T>
68 bool Classifier::GetTopNResults(TfLiteTensor* tensor,
69 std::vector<ClassificationResult>& vecResults,
70 uint32_t topNCount,
71 const std::vector <std::string>& labels)
alexander3c798932021-03-26 21:42:19 +000072 {
73 std::set<std::pair<T, uint32_t>> sortedSet;
74
75 /* NOTE: inputVec's size verification against labels should be
76 * checked by the calling/public function. */
77 T* tensorData = tflite::GetTensorData<T>(tensor);
78
79 /* Set initial elements. */
80 for (uint32_t i = 0; i < topNCount; ++i) {
81 sortedSet.insert({tensorData[i], i});
82 }
83
84 /* Initialise iterator. */
85 auto setFwdIter = sortedSet.begin();
86
87 /* Scan through the rest of elements with compare operations. */
88 for (uint32_t i = topNCount; i < labels.size(); ++i) {
89 if (setFwdIter->first < tensorData[i]) {
90 sortedSet.erase(*setFwdIter);
91 sortedSet.insert({tensorData[i], i});
92 setFwdIter = sortedSet.begin();
93 }
94 }
95
96 /* Final results' container. */
97 vecResults = std::vector<ClassificationResult>(topNCount);
98
alexanderc350cdc2021-04-29 20:36:09 +010099 SetVectorResults<T>(sortedSet, vecResults, tensor, labels);
alexander3c798932021-03-26 21:42:19 +0000100
101 return true;
102 }
103
alexanderc350cdc2021-04-29 20:36:09 +0100104 template bool Classifier::GetTopNResults<uint8_t>(TfLiteTensor* tensor,
105 std::vector<ClassificationResult>& vecResults,
106 uint32_t topNCount, const std::vector <std::string>& labels);
alexander3c798932021-03-26 21:42:19 +0000107
alexanderc350cdc2021-04-29 20:36:09 +0100108 template bool Classifier::GetTopNResults<int8_t>(TfLiteTensor* tensor,
109 std::vector<ClassificationResult>& vecResults,
110 uint32_t topNCount, const std::vector <std::string>& labels);
alexander3c798932021-03-26 21:42:19 +0000111
112 bool Classifier::GetClassificationResults(
113 TfLiteTensor* outputTensor,
114 std::vector<ClassificationResult>& vecResults,
115 const std::vector <std::string>& labels, uint32_t topNCount)
116 {
117 if (outputTensor == nullptr) {
118 printf_err("Output vector is null pointer.\n");
119 return false;
120 }
121
122 uint32_t totalOutputSize = 1;
123 for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++){
124 totalOutputSize *= outputTensor->dims->data[inputDim];
125 }
126
127 /* Sanity checks. */
128 if (totalOutputSize < topNCount) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100129 printf_err("Output vector is smaller than %" PRIu32 "\n", topNCount);
alexander3c798932021-03-26 21:42:19 +0000130 return false;
131 } else if (totalOutputSize != labels.size()) {
132 printf_err("Output size doesn't match the labels' size\n");
133 return false;
alexanderc350cdc2021-04-29 20:36:09 +0100134 } else if (topNCount == 0) {
135 printf_err("Top N results cannot be zero\n");
136 return false;
alexander3c798932021-03-26 21:42:19 +0000137 }
138
139 bool resultState;
140 vecResults.clear();
141
142 /* Get the top N results. */
143 switch (outputTensor->type) {
144 case kTfLiteUInt8:
alexanderc350cdc2021-04-29 20:36:09 +0100145 resultState = GetTopNResults<uint8_t>(outputTensor, vecResults, topNCount, labels);
alexander3c798932021-03-26 21:42:19 +0000146 break;
147 case kTfLiteInt8:
alexanderc350cdc2021-04-29 20:36:09 +0100148 resultState = GetTopNResults<int8_t>(outputTensor, vecResults, topNCount, labels);
alexander3c798932021-03-26 21:42:19 +0000149 break;
150 case kTfLiteFloat32:
alexanderc350cdc2021-04-29 20:36:09 +0100151 resultState = GetTopNResults<float>(outputTensor, vecResults, topNCount, labels);
alexander3c798932021-03-26 21:42:19 +0000152 break;
153 default:
154 printf_err("Tensor type %s not supported by classifier\n", TfLiteTypeGetName(outputTensor->type));
155 return false;
156 }
157
158 if (!resultState) {
alexanderc350cdc2021-04-29 20:36:09 +0100159 printf_err("Failed to get top N results set\n");
alexander3c798932021-03-26 21:42:19 +0000160 return false;
161 }
162
163 return true;
164 }
165
166} /* namespace app */
167} /* namespace arm */