blob: 6fabebe5f6c2aaa3eeb1980997a668a43989d5fd [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Classifier.hpp"
18
alexander3c798932021-03-26 21:42:19 +000019#include "TensorFlowLiteMicro.hpp"
alexander31ae9f02022-02-10 16:15:54 +000020#include "PlatformMath.hpp"
21#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000022
23#include <vector>
24#include <string>
25#include <set>
26#include <cstdint>
alexander31ae9f02022-02-10 16:15:54 +000027#include <cinttypes>
28
alexander3c798932021-03-26 21:42:19 +000029
30namespace arm {
31namespace app {
32
Kshitij Sisodia76a15802021-12-24 11:05:11 +000033 void Classifier::SetVectorResults(std::set<std::pair<float, uint32_t>>& topNSet,
alexanderc350cdc2021-04-29 20:36:09 +010034 std::vector<ClassificationResult>& vecResults,
Kshitij Sisodia76a15802021-12-24 11:05:11 +000035 const std::vector <std::string>& labels)
36 {
alexanderc350cdc2021-04-29 20:36:09 +010037
38 /* Reset the iterator to the largest element - use reverse iterator. */
alexanderc350cdc2021-04-29 20:36:09 +010039
alexanderc350cdc2021-04-29 20:36:09 +010040 auto topNIter = topNSet.rbegin();
41 for (size_t i = 0; i < vecResults.size() && topNIter != topNSet.rend(); ++i, ++topNIter) {
42 vecResults[i].m_normalisedVal = topNIter->first;
43 vecResults[i].m_label = labels[topNIter->second];
44 vecResults[i].m_labelIdx = topNIter->second;
45 }
alexanderc350cdc2021-04-29 20:36:09 +010046 }
47
Kshitij Sisodia76a15802021-12-24 11:05:11 +000048 bool Classifier::GetTopNResults(const std::vector<float>& tensor,
alexanderc350cdc2021-04-29 20:36:09 +010049 std::vector<ClassificationResult>& vecResults,
50 uint32_t topNCount,
51 const std::vector <std::string>& labels)
alexander3c798932021-03-26 21:42:19 +000052 {
Kshitij Sisodia76a15802021-12-24 11:05:11 +000053
54 std::set<std::pair<float , uint32_t>> sortedSet;
alexander3c798932021-03-26 21:42:19 +000055
56 /* NOTE: inputVec's size verification against labels should be
57 * checked by the calling/public function. */
alexander3c798932021-03-26 21:42:19 +000058
59 /* Set initial elements. */
60 for (uint32_t i = 0; i < topNCount; ++i) {
Kshitij Sisodia76a15802021-12-24 11:05:11 +000061 sortedSet.insert({tensor[i], i});
alexander3c798932021-03-26 21:42:19 +000062 }
63
64 /* Initialise iterator. */
65 auto setFwdIter = sortedSet.begin();
66
67 /* Scan through the rest of elements with compare operations. */
68 for (uint32_t i = topNCount; i < labels.size(); ++i) {
Kshitij Sisodia76a15802021-12-24 11:05:11 +000069 if (setFwdIter->first < tensor[i]) {
alexander3c798932021-03-26 21:42:19 +000070 sortedSet.erase(*setFwdIter);
Kshitij Sisodia76a15802021-12-24 11:05:11 +000071 sortedSet.insert({tensor[i], i});
alexander3c798932021-03-26 21:42:19 +000072 setFwdIter = sortedSet.begin();
73 }
74 }
75
76 /* Final results' container. */
77 vecResults = std::vector<ClassificationResult>(topNCount);
Kshitij Sisodia76a15802021-12-24 11:05:11 +000078 SetVectorResults(sortedSet, vecResults, labels);
alexander3c798932021-03-26 21:42:19 +000079
80 return true;
81 }
82
alexander3c798932021-03-26 21:42:19 +000083 bool Classifier::GetClassificationResults(
84 TfLiteTensor* outputTensor,
85 std::vector<ClassificationResult>& vecResults,
Kshitij Sisodia76a15802021-12-24 11:05:11 +000086 const std::vector <std::string>& labels,
87 uint32_t topNCount,
88 bool useSoftmax)
alexander3c798932021-03-26 21:42:19 +000089 {
90 if (outputTensor == nullptr) {
91 printf_err("Output vector is null pointer.\n");
92 return false;
93 }
94
95 uint32_t totalOutputSize = 1;
Kshitij Sisodia76a15802021-12-24 11:05:11 +000096 for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++) {
alexander3c798932021-03-26 21:42:19 +000097 totalOutputSize *= outputTensor->dims->data[inputDim];
98 }
99
100 /* Sanity checks. */
101 if (totalOutputSize < topNCount) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100102 printf_err("Output vector is smaller than %" PRIu32 "\n", topNCount);
alexander3c798932021-03-26 21:42:19 +0000103 return false;
104 } else if (totalOutputSize != labels.size()) {
105 printf_err("Output size doesn't match the labels' size\n");
106 return false;
alexanderc350cdc2021-04-29 20:36:09 +0100107 } else if (topNCount == 0) {
108 printf_err("Top N results cannot be zero\n");
109 return false;
alexander3c798932021-03-26 21:42:19 +0000110 }
111
112 bool resultState;
113 vecResults.clear();
114
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000115 /* De-Quantize Output Tensor */
116 QuantParams quantParams = GetTensorQuantParams(outputTensor);
117
118 /* Floating point tensor data to be populated
119 * NOTE: The assumption here is that the output tensor size isn't too
120 * big and therefore, there's neglibible impact on heap usage. */
121 std::vector<float> tensorData(totalOutputSize);
122
123 /* Populate the floating point buffer */
alexander3c798932021-03-26 21:42:19 +0000124 switch (outputTensor->type) {
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000125 case kTfLiteUInt8: {
126 uint8_t *tensor_buffer = tflite::GetTensorData<uint8_t>(outputTensor);
127 for (size_t i = 0; i < totalOutputSize; ++i) {
128 tensorData[i] = quantParams.scale *
129 (static_cast<float>(tensor_buffer[i]) - quantParams.offset);
130 }
alexander3c798932021-03-26 21:42:19 +0000131 break;
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000132 }
133 case kTfLiteInt8: {
134 int8_t *tensor_buffer = tflite::GetTensorData<int8_t>(outputTensor);
135 for (size_t i = 0; i < totalOutputSize; ++i) {
136 tensorData[i] = quantParams.scale *
137 (static_cast<float>(tensor_buffer[i]) - quantParams.offset);
138 }
alexander3c798932021-03-26 21:42:19 +0000139 break;
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000140 }
141 case kTfLiteFloat32: {
142 float *tensor_buffer = tflite::GetTensorData<float>(outputTensor);
143 for (size_t i = 0; i < totalOutputSize; ++i) {
144 tensorData[i] = tensor_buffer[i];
145 }
alexander3c798932021-03-26 21:42:19 +0000146 break;
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000147 }
alexander3c798932021-03-26 21:42:19 +0000148 default:
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000149 printf_err("Tensor type %s not supported by classifier\n",
150 TfLiteTypeGetName(outputTensor->type));
alexander3c798932021-03-26 21:42:19 +0000151 return false;
152 }
153
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000154 if (useSoftmax) {
155 math::MathUtils::SoftmaxF32(tensorData);
156 }
157
158 /* Get the top N results. */
159 resultState = GetTopNResults(tensorData, vecResults, topNCount, labels);
160
alexander3c798932021-03-26 21:42:19 +0000161 if (!resultState) {
alexanderc350cdc2021-04-29 20:36:09 +0100162 printf_err("Failed to get top N results set\n");
alexander3c798932021-03-26 21:42:19 +0000163 return false;
164 }
165
166 return true;
167 }
alexander3c798932021-03-26 21:42:19 +0000168} /* namespace app */
169} /* namespace arm */