blob: bc2c378779048ced0c6120ef70571cb645bf620f [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Classifier.hpp"
18
19#include "hal.h"
20#include "TensorFlowLiteMicro.hpp"
21
22#include <vector>
23#include <string>
24#include <set>
25#include <cstdint>
26
27namespace arm {
28namespace app {
29
30 template<typename T>
31 bool Classifier::_GetTopNResults(TfLiteTensor* tensor,
32 std::vector<ClassificationResult>& vecResults,
33 uint32_t topNCount,
34 const std::vector <std::string>& labels)
35 {
36 std::set<std::pair<T, uint32_t>> sortedSet;
37
38 /* NOTE: inputVec's size verification against labels should be
39 * checked by the calling/public function. */
40 T* tensorData = tflite::GetTensorData<T>(tensor);
41
42 /* Set initial elements. */
43 for (uint32_t i = 0; i < topNCount; ++i) {
44 sortedSet.insert({tensorData[i], i});
45 }
46
47 /* Initialise iterator. */
48 auto setFwdIter = sortedSet.begin();
49
50 /* Scan through the rest of elements with compare operations. */
51 for (uint32_t i = topNCount; i < labels.size(); ++i) {
52 if (setFwdIter->first < tensorData[i]) {
53 sortedSet.erase(*setFwdIter);
54 sortedSet.insert({tensorData[i], i});
55 setFwdIter = sortedSet.begin();
56 }
57 }
58
59 /* Final results' container. */
60 vecResults = std::vector<ClassificationResult>(topNCount);
61
62 /* For getting the floating point values, we need quantization parameters. */
63 QuantParams quantParams = GetTensorQuantParams(tensor);
64
65 /* Reset the iterator to the largest element - use reverse iterator. */
66 auto setRevIter = sortedSet.rbegin();
67
68 /* Populate results
69 * Note: we could combine this loop with the loop above, but that
70 * would, involve more multiplications and other operations.
71 **/
72 for (size_t i = 0; i < vecResults.size(); ++i, ++setRevIter) {
73 double score = static_cast<int> (setRevIter->first);
74 vecResults[i].m_normalisedVal = quantParams.scale *
75 (score - quantParams.offset);
76 vecResults[i].m_label = labels[setRevIter->second];
77 vecResults[i].m_labelIdx = setRevIter->second;
78 }
79
80 return true;
81 }
82
83 template<>
84 bool Classifier::_GetTopNResults<float>(TfLiteTensor* tensor,
85 std::vector<ClassificationResult>& vecResults,
86 uint32_t topNCount,
87 const std::vector <std::string>& labels)
88 {
89 std::set<std::pair<float, uint32_t>> sortedSet;
90
91 /* NOTE: inputVec's size verification against labels should be
92 * checked by the calling/public function. */
93 float* tensorData = tflite::GetTensorData<float>(tensor);
94
95 /* Set initial elements. */
96 for (uint32_t i = 0; i < topNCount; ++i) {
97 sortedSet.insert({tensorData[i], i});
98 }
99
100 /* Initialise iterator. */
101 auto setFwdIter = sortedSet.begin();
102
103 /* Scan through the rest of elements with compare operations. */
104 for (uint32_t i = topNCount; i < labels.size(); ++i) {
105 if (setFwdIter->first < tensorData[i]) {
106 sortedSet.erase(*setFwdIter);
107 sortedSet.insert({tensorData[i], i});
108 setFwdIter = sortedSet.begin();
109 }
110 }
111
112 /* Final results' container. */
113 vecResults = std::vector<ClassificationResult>(topNCount);
114
115 /* Reset the iterator to the largest element - use reverse iterator. */
116 auto setRevIter = sortedSet.rbegin();
117
118 /* Populate results
119 * Note: we could combine this loop with the loop above, but that
120 * would, involve more multiplications and other operations.
121 **/
122 for (size_t i = 0; i < vecResults.size(); ++i, ++setRevIter) {
123 vecResults[i].m_normalisedVal = setRevIter->first;
124 vecResults[i].m_label = labels[setRevIter->second];
125 vecResults[i].m_labelIdx = setRevIter->second;
126 }
127
128 return true;
129 }
130
131 template bool Classifier::_GetTopNResults<uint8_t>(TfLiteTensor* tensor,
132 std::vector<ClassificationResult>& vecResults,
133 uint32_t topNCount, const std::vector <std::string>& labels);
134
135 template bool Classifier::_GetTopNResults<int8_t>(TfLiteTensor* tensor,
136 std::vector<ClassificationResult>& vecResults,
137 uint32_t topNCount, const std::vector <std::string>& labels);
138
139 bool Classifier::GetClassificationResults(
140 TfLiteTensor* outputTensor,
141 std::vector<ClassificationResult>& vecResults,
142 const std::vector <std::string>& labels, uint32_t topNCount)
143 {
144 if (outputTensor == nullptr) {
145 printf_err("Output vector is null pointer.\n");
146 return false;
147 }
148
149 uint32_t totalOutputSize = 1;
150 for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++){
151 totalOutputSize *= outputTensor->dims->data[inputDim];
152 }
153
154 /* Sanity checks. */
155 if (totalOutputSize < topNCount) {
156 printf_err("Output vector is smaller than %u\n", topNCount);
157 return false;
158 } else if (totalOutputSize != labels.size()) {
159 printf_err("Output size doesn't match the labels' size\n");
160 return false;
161 }
162
163 bool resultState;
164 vecResults.clear();
165
166 /* Get the top N results. */
167 switch (outputTensor->type) {
168 case kTfLiteUInt8:
169 resultState = _GetTopNResults<uint8_t>(outputTensor, vecResults, topNCount, labels);
170 break;
171 case kTfLiteInt8:
172 resultState = _GetTopNResults<int8_t>(outputTensor, vecResults, topNCount, labels);
173 break;
174 case kTfLiteFloat32:
175 resultState = _GetTopNResults<float>(outputTensor, vecResults, topNCount, labels);
176 break;
177 default:
178 printf_err("Tensor type %s not supported by classifier\n", TfLiteTypeGetName(outputTensor->type));
179 return false;
180 }
181
182 if (!resultState) {
183 printf_err("Failed to get sorted set\n");
184 return false;
185 }
186
187 return true;
188 }
189
190} /* namespace app */
191} /* namespace arm */