blob: e498f06674e60712d647d19a07069ba9ea9e6d97 [file] [log] [blame]
Richard Burtonec5e99b2022-10-05 11:00:37 +01001/*
Conor Kennedy5cf8e742023-02-13 10:50:40 +00002 * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Richard Burtonec5e99b2022-10-05 11:00:37 +01003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "KwsClassifier.hpp"
18
19#include "TensorFlowLiteMicro.hpp"
20#include "PlatformMath.hpp"
21#include "log_macros.h"
22#include "../include/KwsClassifier.hpp"
23
24
25#include <vector>
26#include <algorithm>
27#include <string>
28#include <set>
29#include <cstdint>
30#include <cinttypes>
31
32
33namespace arm {
34namespace app {
35
36 bool KwsClassifier::GetClassificationResults(TfLiteTensor* outputTensor,
37 std::vector<ClassificationResult>& vecResults, const std::vector <std::string>& labels,
38 uint32_t topNCount, bool useSoftmax, std::vector<std::vector<float>>& resultHistory)
39 {
40 if (outputTensor == nullptr) {
41 printf_err("Output vector is null pointer.\n");
42 return false;
43 }
44
45 uint32_t totalOutputSize = 1;
46 for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++) {
47 totalOutputSize *= outputTensor->dims->data[inputDim];
48 }
49
Conor Kennedy5cf8e742023-02-13 10:50:40 +000050 /* Health check */
Richard Burtonec5e99b2022-10-05 11:00:37 +010051 if (totalOutputSize < topNCount) {
52 printf_err("Output vector is smaller than %" PRIu32 "\n", topNCount);
53 return false;
54 } else if (totalOutputSize != labels.size()) {
55 printf_err("Output size doesn't match the labels' size\n");
56 return false;
57 } else if (topNCount == 0) {
58 printf_err("Top N results cannot be zero\n");
59 return false;
60 }
61
62 bool resultState;
63 vecResults.clear();
64
65 /* De-Quantize Output Tensor */
66 QuantParams quantParams = GetTensorQuantParams(outputTensor);
67
68 /* Floating point tensor data to be populated
69 * NOTE: The assumption here is that the output tensor size isn't too
70 * big and therefore, there's neglibible impact on heap usage. */
71 std::vector<float> resultData(totalOutputSize);
72 resultData.resize(totalOutputSize);
73
74 /* Populate the floating point buffer */
75 switch (outputTensor->type) {
76 case kTfLiteUInt8: {
77 uint8_t* tensor_buffer = tflite::GetTensorData<uint8_t>(outputTensor);
78 for (size_t i = 0; i < totalOutputSize; ++i) {
79 resultData[i] = quantParams.scale *
80 (static_cast<float>(tensor_buffer[i]) - quantParams.offset);
81 }
82 break;
83 }
84 case kTfLiteInt8: {
85 int8_t* tensor_buffer = tflite::GetTensorData<int8_t>(outputTensor);
86 for (size_t i = 0; i < totalOutputSize; ++i) {
87 resultData[i] = quantParams.scale *
88 (static_cast<float>(tensor_buffer[i]) - quantParams.offset);
89 }
90 break;
91 }
92 case kTfLiteFloat32: {
93 float* tensor_buffer = tflite::GetTensorData<float>(outputTensor);
94 for (size_t i = 0; i < totalOutputSize; ++i) {
95 resultData[i] = tensor_buffer[i];
96 }
97 break;
98 }
99 default:
100 printf_err("Tensor type %s not supported by classifier\n",
101 TfLiteTypeGetName(outputTensor->type));
102 return false;
103 }
104
105 if (useSoftmax) {
106 math::MathUtils::SoftmaxF32(resultData);
107 }
108
109 /* If keeping track of recent results, update and take an average. */
110 if (resultHistory.size() > 1) {
111 std::rotate(resultHistory.begin(), resultHistory.begin() + 1, resultHistory.end());
112 resultHistory.back() = resultData;
113 AveragResults(resultHistory, resultData);
114 }
115
116 /* Get the top N results. */
117 resultState = GetTopNResults(resultData, vecResults, topNCount, labels);
118
119 if (!resultState) {
120 printf_err("Failed to get top N results set\n");
121 return false;
122 }
123
124 return true;
125 }
126
127 void app::KwsClassifier::AveragResults(const std::vector<std::vector<float>>& resultHistory,
128 std::vector<float>& averageResult)
129 {
130 /* Compute averages of each class across the window length. */
131 float sum;
132 for (size_t j = 0; j < averageResult.size(); j++) {
133 sum = 0;
134 for (size_t i = 0; i < resultHistory.size(); i++) {
135 sum += resultHistory[i][j];
136 }
137 averageResult[j] = (sum / resultHistory.size());
138 }
139 }
140
141} /* namespace app */
142} /* namespace arm */